
Learning PyTorch Data Loading

A brief analysis of Dataset, Sampler, DataLoader, and DataLoaderIter

1. Dataset

torch.utils.data.Dataset is an abstract class; all other dataset classes are its subclasses, and every subclass should override __len__ and __getitem__.

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
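
As a concrete illustration, a minimal custom dataset only needs these two methods. The following is a sketch; the class name ListDataset and the sample data are made up for illustration, not from the source:

import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """A hypothetical dataset wrapping a plain Python list of (data, label) pairs."""
    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        data, label = self.samples[index]
        return torch.tensor(data), torch.tensor(label)

    def __len__(self):
        return len(self.samples)

dataset = ListDataset([([1.0, 2.0], 0), ([3.0, 4.0], 1)])
print(len(dataset))   # 2
print(dataset[0])     # (tensor([1., 2.]), tensor(0))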

Among its subclasses, torch.utils.data.TensorDataset wraps tensors into a dataset; each sample is obtained by indexing into the tensors.

class TensorDataset(Dataset):
    def __init__(self, *tensors):
        # every tensor must have the same size in dim 0 (the number of samples)
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors  # store the tensors

    def __getitem__(self, index):
        # return the index-th slice of every tensor, i.e. the index-th sample
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
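
A quick usage sketch of TensorDataset (the tensors here are made-up example data):

import torch
from torch.utils.data import TensorDataset

data = torch.randn(5, 3)       # 5 samples, 3 features each
labels = torch.arange(5)       # 5 labels
dataset = TensorDataset(data, labels)
print(len(dataset))            # 5
print(dataset[2])              # (data[2], labels[2]) as a tuple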

2. Sampler

torch.utils.data.Sampler is the class responsible for generating indices into a Dataset.

class Sampler(object):
    def __init__(self, data_source):
        pass

    def __iter__(self):  # returns an iterator over a sequence of dataset indices
        raise NotImplementedError

    def __len__(self):  # the length of the returned iterator
        raise NotImplementedError

Sampler has subclasses SequentialSampler and RandomSampler, defined as follows:

# sequential sampling
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

# random sampling
class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())  # a random permutation of the indices

    def __len__(self):
        return len(self.data_source)
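
For intuition, iterating over these samplers simply yields indices; a small sketch (the RandomSampler order differs from run to run):

from torch.utils.data.sampler import SequentialSampler, RandomSampler

list(SequentialSampler(range(5)))   # [0, 1, 2, 3, 4]
list(RandomSampler(range(5)))       # e.g. [3, 0, 4, 1, 2]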

torch.utils.data.BatchSampler is built on top of a Sampler and is used to generate batches of indices.

class BatchSampler(object):
    # sampler: the underlying Sampler; batch_size: size of each batch;
    # drop_last: whether to drop the last batch if it is not full
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)  # collect indices from the sampler into a batch
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

An example of drop_last:

>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

Other BatchSampler examples:

from torch.utils.data.sampler import BatchSampler, RandomSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

3. DataLoader

torch.utils.data.DataLoader is responsible for loading the data and supports multi-process loading.

Its interface is defined as follows:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

dataset: the dataset to load the data from.

batch_size: how many samples are loaded per batch.

shuffle: when set to True, a RandomSampler is used to draw indices at random.

sampler: the strategy used to draw samples from the dataset; if a sampler is specified, shuffle must be False.

batch_sampler: like sampler, but returns a whole batch of indices at a time; if batch_sampler is specified, it is mutually exclusive with batch_size, shuffle, sampler, and drop_last.

num_workers: the number of subprocesses used for data loading.

collate_fn: a callable that merges the batch_size samples fetched from a map-style dataset (each a tuple of length 2: the data and its label) into a list of length 2: a FloatTensor made of the batch_size data items and a LongTensor made of the batch_size labels.

pin_memory: if True, the DataLoader copies tensors into CUDA pinned memory before returning them.

drop_last: whether to drop the last batch if it is not full.

timeout: if positive, the timeout value for collecting a batch from the workers; it must always be non-negative, and an error is raised if no data arrives within this time.

worker_init_fn: a callable; if not None, it is called in every worker subprocess with the worker id (an int in [0, num_workers - 1]) as input.

prefetch_factor: the number of samples each worker loads in advance; default 2.

persistent_workers: if True, the DataLoader does not shut down the worker processes after the dataset has been consumed once, keeping them alive for subsequent iterations.
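
To tie these parameters together, here is a minimal usage sketch (the dataset contents are made-up example data; only standard DataLoader arguments are used):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.arange(100))
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=0, drop_last=False)

for data, labels in loader:
    print(data.shape, labels.shape)   # torch.Size([32, 3]) torch.Size([32]); last batch has 4 samples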

collate_fn

When collate_fn is applied to a list of samples to assemble them into a batch (rather than to a single sample), it usually does three things, as shown in the sketch after this list:

  1. Adds a new first (batch) dimension
  2. Automatically converts NumPy arrays and Python numbers to tensors
  3. Preserves the data structure
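
A small sketch of this behavior with the default collate_fn (the tensors are made-up example data):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
batch = next(iter(DataLoader(dataset, batch_size=4)))

print(len(batch))       # 2 -- the (data, label) structure is preserved
print(batch[0].shape)   # torch.Size([4, 3]) -- a new first (batch) dimension was added
print(batch[1])         # tensor([0, 1, 2, 3]) -- labels collated into a LongTensor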
class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        # check for conflicting arguments
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')
        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # if no BatchSampler was given, one is constructed here
        if batch_sampler is None:
            # if no Sampler was given either, one is constructed here
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        # use the (possibly custom) sampler and batch sampler
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        # return PyTorch's (multi-process-capable) iterator to actually load the data
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

4. DataLoaderIter

When iter(DataLoader) is called, a DataLoaderIter is returned.

(Only the single-process path (num_workers=0) is examined here; the multi-process path is not covered yet.)

It is defined as follows:

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)  # unlike DataLoader, DataLoaderIter keeps a sample_iter

        if self.num_workers > 0:  # everything below sets up the multi-process path
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the pipeline with 2 batches of indices per worker
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __next__(self):
        if self.num_workers == 0:  # single-process path
            indices = next(self.sample_iter)  # get the indices for one batch
            batch = self.collate_fn([self.dataset[i] for i in indices])  # pack the samples into one batch with collate_fn
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch  # return this batch

        # everything below handles the multi-process path
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # this batch arrived out of order; stash it until its turn comes
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(True, self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1
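
For the single-process case, __next__ is roughly equivalent to the following manual loop. This is only a sketch; it assumes default_collate can be imported from torch.utils.data (available in recent PyTorch versions), and the dataset contents are made-up example data:

import torch
from torch.utils.data import TensorDataset, default_collate
from torch.utils.data.sampler import BatchSampler, SequentialSampler

dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

for indices in batch_sampler:                                # sample_iter yields one batch of indices
    batch = default_collate([dataset[i] for i in indices])   # collate_fn packs the samples into a batch
    print(batch[0].shape, batch[1])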