A Brief Analysis of Dataset, Sampler, DataLoader, and DataLoaderIter

1. Dataset

torch.utils.data.Dataset is an abstract class: every other dataset class is a subclass of it, and every subclass should override __len__ and __getitem__.
```python
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
```
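A custom dataset therefore just subclasses Dataset and implements these two methods. Below is a minimal sketch; the class name SquaresDataset is made up purely for illustration:

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy map-style dataset: sample i is the pair (i, i**2)."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        x = torch.tensor(float(index))
        return x, x ** 2

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[2])    # (tensor(2.), tensor(4.))
```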
Among its subclasses, torch.utils.data.TensorDataset wraps one or more tensors into a dataset; each sample is obtained by indexing the tensors along their first dimension.
```python
class TensorDataset(Dataset):
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
```
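A minimal usage sketch of TensorDataset (the tensor shapes here are arbitrary):

```python
import torch
from torch.utils.data import TensorDataset

data = torch.randn(4, 3)             # 4 samples with 3 features each
labels = torch.tensor([0, 1, 0, 1])  # one label per sample

ds = TensorDataset(data, labels)
print(len(ds))     # 4
x, y = ds[0]       # indexing returns a tuple: (data[0], labels[0])
print(x.shape, y)  # torch.Size([3]) tensor(0)
```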
2. Sampler

torch.utils.data.Sampler is the class responsible for generating indices into a Dataset.
```python
class Sampler(object):
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
```
Sampler has the subclasses SequentialSampler and RandomSampler, defined as follows:
```python
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)
```
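A quick sketch of what the two samplers yield over a small data source (the random output shown is just one possible permutation):

```python
from torch.utils.data.sampler import SequentialSampler, RandomSampler

data = list(range(5))
print(list(SequentialSampler(data)))  # [0, 1, 2, 3, 4]
print(list(RandomSampler(data)))      # e.g. [3, 0, 4, 1, 2]
```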
torch.utils.data.BatchSampler is built on top of a Sampler and is used to generate batches of indices.
```python
class BatchSampler(object):
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
```
An example of drop_last:
```python
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```
Other BatchSampler examples:
```python
from torch.utils.data.sampler import BatchSampler, RandomSampler

[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
# e.g. [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]
```
3. DataLoader

torch.utils.data.DataLoader is responsible for loading the data and supports multi-process loading.

Its interface is defined as follows:
```python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
```
- `dataset`: the dataset to load from.
- `batch_size`: how many samples to load per batch.
- `shuffle`: if True, a RandomSampler is used so that indices are drawn in random order.
- `sampler`: the strategy for drawing samples from the dataset; if a sampler is specified, shuffle must be False.
- `batch_sampler`: like sampler, but returns a whole batch of indices at a time; it is mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- `num_workers`: the number of subprocesses used for data loading (0 means the data is loaded in the main process).
- `collate_fn`: a callable that merges the batch_size samples fetched from a map-style dataset (each a tuple of length 2, data first and label second) into a single list of length 2: a FloatTensor holding the batch_size data items and a LongTensor holding the batch_size labels.
- `pin_memory`: if True, the DataLoader copies tensors into CUDA pinned memory before returning them.
- `drop_last`: whether to drop the last batch if it is not full.
- `timeout`: if positive, the timeout for collecting a batch from the workers; it must be non-negative, and an error is raised if no batch is received within this time.
- `worker_init_fn`: a callable; if not None, it is called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as its input.
- `prefetch_factor`: the number of samples each worker loads in advance; the default is 2.
- `persistent_workers`: if True, the DataLoader does not shut down the worker processes after the dataset has been consumed once, so the workers stay alive across epochs.
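A minimal sketch tying the most common parameters together, assuming a TensorDataset with 10 samples:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,))
ds = TensorDataset(data, labels)

loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 3]) torch.Size([4]); the last batch has 2 samples
```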
collate_fn

When collate_fn operates on a list of data samples, i.e. when it collates the input samples into a batch (rather than processing a single sample), it typically does the following three things (a sketch follows the list):

- adds a new first dimension (the batch dimension)
- automatically converts NumPy arrays and Python numbers to tensors
- preserves the data structure
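A rough sketch of a custom collate_fn that mimics this default behaviour for (data, label) samples; the function simple_collate is made up for illustration and is not part of PyTorch:

```python
import torch

def simple_collate(samples):
    """samples: list of (data, label) pairs, one per example."""
    # Stacking adds the batch dimension as dim 0 and yields a FloatTensor of the data.
    data = torch.stack([torch.as_tensor(x, dtype=torch.float32) for x, _ in samples])
    # Python-number labels become a LongTensor; the (data, labels) structure is preserved.
    labels = torch.tensor([y for _, y in samples], dtype=torch.long)
    return data, labels

batch = simple_collate([([1.0, 2.0], 0), ([3.0, 4.0], 1)])
print(batch[0].shape, batch[1])  # torch.Size([2, 2]) tensor([0, 1])
```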
The (simplified) implementation of DataLoader is as follows:
```python
class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
```
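The last branch of __init__ is what makes shuffle and batch_size work: when no sampler is given, DataLoader builds a SequentialSampler (or a RandomSampler if shuffle=True) and wraps it in a BatchSampler, and __len__ simply delegates to that BatchSampler. A small sketch of this composition:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10.0))
loader = DataLoader(ds, batch_size=3, shuffle=False)

print(type(loader.sampler).__name__)        # SequentialSampler
print(type(loader.batch_sampler).__name__)  # BatchSampler
print(len(loader))                          # 4, i.e. len(loader.batch_sampler) = ceil(10 / 3)
```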
4. DataLoaderIter

Calling iter(DataLoader) returns a DataLoaderIter.

(I have only looked through the single-process path, i.e. num_workers == 0; the multi-process path is not covered here.)

It is defined as follows:
```python
class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue,
                          self.collate_fn, base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the pipeline: keep 2 * num_workers batches of indices in flight
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next batch has already been produced (it may arrive out of order)
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order batches until their turn comes
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(True, self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1
```
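As a sanity check of the single-process (num_workers == 0) branch of __next__, the whole step reduces to: draw one batch of indices from the batch sampler, index the dataset, and collate. The helper below is an illustrative re-implementation, not part of PyTorch:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import SequentialSampler, BatchSampler
from torch.utils.data.dataloader import default_collate

def next_batch_single_process(dataset, sample_iter, collate_fn=default_collate):
    indices = next(sample_iter)                       # one batch of indices; raises StopIteration when exhausted
    return collate_fn([dataset[i] for i in indices])  # fetch the samples and collate them into tensors

ds = TensorDataset(torch.arange(10.0), torch.arange(10))
sample_iter = iter(BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False))

print(next_batch_single_process(ds, sample_iter))
# (tensor([0., 1., 2., 3.]), tensor([0, 1, 2, 3]))
```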