A Brief Analysis of Dataset, Sampler, DataLoader, and DataLoaderIter
1. Dataset
torch.utils.data.Dataset is an abstract class. All other dataset classes are its subclasses, and every subclass should override __len__ and __getitem__.
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
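For example, a minimal custom map-style dataset only needs these two methods. The toy "squares" dataset below is an illustrative sketch, not part of PyTorch:

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset: sample i is the pair (i, i*i) as float tensors."""
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Each sample is a (data, label)-style tuple.
        return torch.tensor(float(index)), torch.tensor(float(index * index))

    def __len__(self):
        return self.n

ds = SquaresDataset(5)
print(len(ds), ds[2])   # 5 (tensor(2.), tensor(4.))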
 
Among its subclasses, torch.utils.data.TensorDataset wraps one or more tensors into a dataset; each sample is obtained by indexing the tensors along the first dimension.
class TensorDataset(Dataset):
    def __init__(self, *tensors):
        # All tensors must have the same size in the first (sample) dimension.
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
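A quick usage sketch (the random tensors below are just placeholders for illustration):

import torch
from torch.utils.data import TensorDataset

data = torch.randn(100, 3)             # 100 samples with 3 features each
labels = torch.randint(0, 2, (100,))   # 100 integer labels
ds = TensorDataset(data, labels)

print(len(ds))     # 100
x, y = ds[0]       # indexing returns the tuple (data[0], labels[0])
print(x.shape, y)  # torch.Size([3]) tensor(0) or tensor(1)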
 
2. Sampler
torch.utils.data.Sampler is the class responsible for generating indices into a Dataset.
class Sampler(object):
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
 
Sampler has the subclasses SequentialSampler and RandomSampler, defined as follows:
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)
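A small sketch of the difference between the two (a plain range works as the data_source here, since only its length is used):

from torch.utils.data.sampler import SequentialSampler, RandomSampler

print(list(SequentialSampler(range(5))))  # [0, 1, 2, 3, 4]
print(list(RandomSampler(range(5))))      # a random permutation of 0..4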
 
torch.utils.data.BatchSampler is built on top of a Sampler and is used to generate batches of indices.
class BatchSampler(object):
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
 
An example of drop_last:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
 
Other BatchSampler examples:
from torch.utils.data.sampler import BatchSampler, RandomSampler

[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]
 
3. DataLoader
torch.utils.data.DataLoader is responsible for loading the data and supports multi-process loading.
Its interface is defined as follows:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
 
- dataset: the dataset to load from.
- batch_size: how many samples to load per batch.
- shuffle: if True, a RandomSampler is used so that indices are drawn in random order.
- sampler: the strategy for drawing samples from the dataset; if sampler is specified, shuffle must be False.
- batch_sampler: similar to sampler, but returns one batch of indices at a time; if batch_sampler is specified, it is mutually exclusive with batch_size, shuffle, sampler, and drop_last, which must be left at their defaults.
- num_workers: the number of subprocesses used for data loading.
- collate_fn: a callable that merges the batch_size samples fetched from a map-style dataset (each a tuple of length 2, with the data first and the label second) into a single list of length 2: a FloatTensor stacking the batch_size data items and a LongTensor stacking the batch_size labels.
- pin_memory: if True, the DataLoader copies the tensors into CUDA pinned memory before returning them.
- drop_last: whether to drop the last batch if it is not full.
- timeout: if positive, the timeout for collecting a batch from the workers; it must always be non-negative, and an error is raised if no data is read within this time.
- worker_init_fn: a callable; if not None, it is called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input.
- prefetch_factor: the number of samples loaded in advance by each worker; defaults to 2.
- persistent_workers: if True, the DataLoader does not shut down the worker processes after the dataset has been consumed once, so they can be reused across epochs.
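Putting the main parameters together, a typical construction looks roughly like this (the TensorDataset and batch size are arbitrary choices for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

loader = DataLoader(ds,
                    batch_size=8,
                    shuffle=True,    # internally builds a RandomSampler
                    num_workers=0,   # 0 means load in the main process
                    drop_last=False)

for x, y in loader:
    # x: torch.Size([8, 3]), y: torch.Size([8]); the last batch holds only 4 samples
    print(x.shape, y.shape)
    break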
collate_fn
When collate_fn is applied to a list of data samples to collate them into a batch (rather than to a single sample), it typically does three things (see the sketch after this list):

- it adds a new first (batch) dimension;

- it automatically converts numpy arrays and Python numbers into tensors;

- it preserves the data structure of each sample.
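As a minimal sketch of this behaviour, the default collate function (default_collate) can be called directly on a hand-made list of (data, label) samples (the toy values below are assumptions for illustration):

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

# Three (data, label) samples, as a map-style dataset would return them.
samples = [(np.array([1.0, 2.0]), 0),
           (np.array([3.0, 4.0]), 1),
           (np.array([5.0, 6.0]), 2)]

batch = default_collate(samples)
# The tuple structure is preserved, a batch dimension is added, and the
# numpy arrays / Python ints are converted to tensors:
#   batch[0] -> tensor of shape (3, 2)
#   batch[1] -> tensor([0, 1, 2])
print(batch[0].shape, batch[1])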
 
The corresponding implementation of DataLoader:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # If no batch_sampler is given, build one from a sampler:
        # shuffle=True -> RandomSampler, otherwise SequentialSampler.
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
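Based on this constructor, passing batch_size plus shuffle=True is roughly equivalent to building the sampler chain yourself, as in the sketch below (the TensorDataset is again an arbitrary illustrative choice):

import torch
from torch.utils.data import (DataLoader, TensorDataset,
                              RandomSampler, BatchSampler)

ds = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

# shuffle=True makes the DataLoader build RandomSampler + BatchSampler internally...
loader_a = DataLoader(ds, batch_size=8, shuffle=True)

# ...which is roughly what this explicit construction does:
loader_b = DataLoader(ds, batch_sampler=BatchSampler(RandomSampler(ds),
                                                     batch_size=8,
                                                     drop_last=False))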
 
4. DataLoaderIter
When iter() is called on a DataLoader, a DataLoaderIter is returned.
(Only the single-process path is examined closely here; the multi-process path is not analyzed in detail.)
It is defined as follows:
class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline: send two batches of indices per worker up front.
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __next__(self):
        if self.num_workers == 0:
            # Single-process path: fetch the next batch of indices, index the
            # dataset directly, and collate the samples into a batch.
            indices = next(self.sample_iter)
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # Multi-process path: first check whether the next expected batch has
        # already arrived out of order.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # Batches may arrive out of order; buffer them until their turn.
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(True, self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1
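Finally, a small sketch of what a for loop over a DataLoader actually does: iter() produces the iterator object and next() drives it batch by batch until StopIteration (single-process case; the toy one-tensor dataset is an assumption for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float())
loader = DataLoader(ds, batch_size=4)   # num_workers=0: single-process iterator

it = iter(loader)
print(next(it))   # [tensor([0., 1., 2., 3.])]
print(next(it))   # [tensor([4., 5., 6., 7.])]
print(next(it))   # [tensor([8., 9.])]
# a fourth next(it) would raise StopIteration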