# -*- coding: utf-8 -*-
from collections import namedtuple
import torch
import torch.distributed as dist
from ..utils.alg import kmeans


class Dataset(torch.utils.data.Dataset):
    r"""
    Dataset that is compatible with :class:`torch.utils.data.Dataset`.
    This serves as a wrapper for manipulating all data fields
    with the operating behaviours defined in :class:`Transform`.
    The data fields of all the instantiated sentences can be accessed as an attribute of the dataset.

    Args:
        transform (Transform):
            An instance of :class:`Transform` or one of its derivations.
            The instance holds a series of loading and processing behaviours
            with regard to the specific data format.
        data (list[list] or str):
            A list of lists of strings, or a filename.
            This will be passed into :meth:`transform.load`.
        kwargs (dict):
            Keyword arguments that will be passed into :meth:`transform.load` together with ``data``
            to control the loading behaviour.

    Attributes:
        transform (Transform):
            An instance of :class:`Transform`.
        sentences (list[Sentence]):
            A list of sentences loaded from the data.
            Each sentence includes fields obeying the data format defined in ``transform``.
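
    Examples:
        A minimal usage sketch. Here ``transform`` stands for any built
        :class:`Transform` instance, and the file path is an illustrative
        assumption, not part of this module::

            >>> dataset = Dataset(transform, 'data/train.conllx')
            >>> dataset.build(batch_size=5000, n_buckets=32, shuffle=True)
            >>> for batch in dataset.loader:
            ...     ...  # each batch is a namedtuple of composed fields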
"""

    def __init__(self, transform, data, **kwargs):
        super().__init__()

        self.transform = transform
        self.sentences = transform.load(data, **kwargs)

    def __repr__(self):
        s = f"{self.__class__.__name__}("
        s += f"n_sentences={len(self.sentences)}"
        if hasattr(self, 'loader'):
            s += f", n_batches={len(self.loader)}"
        if hasattr(self, 'buckets'):
            s += f", n_buckets={len(self.buckets)}"
        s += ")"

        return s

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        if not hasattr(self, 'fields'):
            raise RuntimeError("The fields are not numericalized. Please build the dataset first.")
        # lazily yield the numericalized value of each field for this index
        for d in self.fields.values():
            yield d[index]

    def __getattr__(self, name):
        # only invoked when normal attribute lookup fails; fields stored on
        # the sentences are thereby exposed as attributes of the dataset
        if name in self.__dict__:
            return self.__dict__[name]
        return [getattr(sentence, name) for sentence in self.sentences]

    def __setattr__(self, name, value):
        if 'sentences' in self.__dict__ and name in self.sentences[0]:
            # restore the order of sequences in the buckets
            indices = torch.tensor([i
                                    for bucket in self.buckets.values()
                                    for i in bucket]).argsort()
            for index, sentence in zip(indices, self.sentences):
                setattr(sentence, name, value[index])
        else:
            self.__dict__[name] = value

    def __getstate__(self):
        # only pickle the Transform object and the sentences
        return {'transform': self.transform, 'sentences': self.sentences}

    def __setstate__(self, state):
        self.__dict__.update(state)

    def collate_fn(self, batch):
        # regroup the batch from per-sentence tuples to per-field tuples,
        # e.g., [(w1, t1), (w2, t2)] -> {words: (w1, w2), tags: (t1, t2)}
        return {f: d for f, d in zip(self.fields.keys(), zip(*batch))}

    def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False):
        # numericalize all fields
        self.fields = self.transform(self.sentences)
        # use the lengths of the first field to cluster sentences into buckets
        # NOTE: the final bucket count is roughly equal to n_buckets
        self.lengths = [len(i) for i in self.fields[next(iter(self.fields))]]
        self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))
        self.loader = DataLoader(dataset=self,
                                 batch_sampler=Sampler(buckets=self.buckets,
                                                       batch_size=batch_size,
                                                       shuffle=shuffle,
                                                       distributed=distributed),
                                 collate_fn=self.collate_fn)


class DataLoader(torch.utils.data.DataLoader):
    r"""
    DataLoader, matching with :class:`Dataset`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __iter__(self):
        for batch in super().__iter__():
            # compose the raw field values into model-ready inputs and pack
            # them into a lightweight namedtuple keyed by field name
            yield namedtuple('Batch', [f.name for f in batch.keys()])(*[f.compose(d) for f, d in batch.items()])


class Sampler(torch.utils.data.Sampler):
    r"""
    Sampler that supports bucketing and token-level batching.

    Args:
        buckets (dict):
            A dict that maps each centroid to the indices of the sentences clustered to it.
            The centroid corresponds to the average length of all sentences in the bucket.
        batch_size (int):
            Token-level batch size. The resulting batch contains roughly the same number of tokens as ``batch_size``.
        shuffle (bool):
            If ``True``, the sampler shuffles both the buckets and the samples in each bucket. Default: ``False``.
        distributed (bool):
            If ``True``, the sampler is used in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`,
            restricting each process to a distinct subset of the dataset.
            Default: ``False``.
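
    Examples:
        A toy illustration with made-up numbers. Two buckets hold sentences of
        average lengths 10 and 20; with ``batch_size=40`` tokens, each bucket
        then fits into a single batch::

            >>> buckets = {10: [0, 2, 4, 5], 20: [1, 3]}
            >>> list(Sampler(buckets, batch_size=40))
            [[0, 2, 4, 5], [1, 3]]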
"""

    def __init__(self, buckets, batch_size, shuffle=False, distributed=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sizes, self.buckets = zip(*[(size, bucket) for size, bucket in buckets.items()])
        # number of chunks in each bucket, clipped by the range [1, len(bucket)]
        self.chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1))
                       for size, bucket in zip(self.sizes, self.buckets)]

        self.rank = dist.get_rank() if distributed else 0
        self.replicas = dist.get_world_size() if distributed else 1
        self.samples = sum(self.chunks) // self.replicas
        self.epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        range_fn = torch.arange
        # if `shuffle=True`, shuffle both the buckets and the samples in each bucket
        # for distributed training, make sure each process generates the same random sequence at each epoch
        if self.shuffle:
            def range_fn(x):
                return torch.randperm(x, generator=g)

        total, count = 0, 0
        # TODO: more elegant way to deal with uneven data, which we directly discard right now
        for i in range_fn(len(self.buckets)).tolist():
            split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1
                           for j in range(self.chunks[i])]
            # DON'T use `torch.chunk` here, which may return a wrong number of chunks
            for batch in range_fn(len(self.buckets[i])).split(split_sizes):
                if count == self.samples:
                    break
                if total % self.replicas == self.rank:
                    count += 1
                    yield [self.buckets[i][j] for j in batch.tolist()]
                total += 1
        self.epoch += 1

    def __len__(self):
        return self.samples
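

# A minimal end-to-end sketch of how the pieces above fit together under
# distributed training; ``transform`` and the file path are illustrative
# assumptions, not part of this module:
#
#     dist.init_process_group('nccl')
#     dataset = Dataset(transform, 'data/train.conllx')
#     dataset.build(batch_size=5000, n_buckets=32, shuffle=True, distributed=True)
#     for epoch in range(n_epochs):
#         # every process draws a disjoint subset of batches; the sampler
#         # reseeds itself from its internal epoch counter after each pass,
#         # so all processes shuffle identically without an explicit set_epoch
#         for batch in dataset.loader:
#             ...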