dataset.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import bisect
  2. import warnings
  3. from typing import (
  4. Generic,
  5. Iterable,
  6. Iterator,
  7. List,
  8. Optional,
  9. Sequence,
  10. Tuple,
  11. TypeVar,
  12. )
  13. # No 'default_generator' in torch/__init__.pyi
  14. from torch import default_generator, randperm
  15. from torch._utils import _accumulate
  16. from ... import Generator, Tensor
  17. __all__ = [
  18. "Dataset",
  19. "IterableDataset",
  20. "TensorDataset",
  21. "ConcatDataset",
  22. "ChainDataset",
  23. "Subset",
  24. "random_split",
  25. ]
  26. T_co = TypeVar('T_co', covariant=True)
  27. T = TypeVar('T')
  28. class Dataset(Generic[T_co]):
  29. r"""An abstract class representing a :class:`Dataset`.
  30. All datasets that represent a map from keys to data samples should subclass
  31. it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
  32. data sample for a given key. Subclasses could also optionally overwrite
  33. :meth:`__len__`, which is expected to return the size of the dataset by many
  34. :class:`~torch.utils.data.Sampler` implementations and the default options
  35. of :class:`~torch.utils.data.DataLoader`.
  36. .. note::
  37. :class:`~torch.utils.data.DataLoader` by default constructs a index
  38. sampler that yields integral indices. To make it work with a map-style
  39. dataset with non-integral indices/keys, a custom sampler must be provided.
  40. """
  41. def __getitem__(self, index) -> T_co:
  42. raise NotImplementedError
  43. def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
  44. return ConcatDataset([self, other])
  45. # No `def __len__(self)` default?
  46. # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  47. # in pytorch/torch/utils/data/sampler.py
  48. class IterableDataset(Dataset[T_co]):
  49. r"""An iterable Dataset.
  50. All datasets that represent an iterable of data samples should subclass it.
  51. Such form of datasets is particularly useful when data come from a stream.
  52. All subclasses should overwrite :meth:`__iter__`, which would return an
  53. iterator of samples in this dataset.
  54. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
  55. item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
  56. iterator. When :attr:`num_workers > 0`, each worker process will have a
  57. different copy of the dataset object, so it is often desired to configure
  58. each copy independently to avoid having duplicate data returned from the
  59. workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
  60. process, returns information about the worker. It can be used in either the
  61. dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
  62. :attr:`worker_init_fn` option to modify each copy's behavior.
  63. Example 1: splitting workload across all workers in :meth:`__iter__`::
  64. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  65. ... def __init__(self, start, end):
  66. ... super(MyIterableDataset).__init__()
  67. ... assert end > start, "this example code only works with end >= start"
  68. ... self.start = start
  69. ... self.end = end
  70. ...
  71. ... def __iter__(self):
  72. ... worker_info = torch.utils.data.get_worker_info()
  73. ... if worker_info is None: # single-process data loading, return the full iterator
  74. ... iter_start = self.start
  75. ... iter_end = self.end
  76. ... else: # in a worker process
  77. ... # split workload
  78. ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
  79. ... worker_id = worker_info.id
  80. ... iter_start = self.start + worker_id * per_worker
  81. ... iter_end = min(iter_start + per_worker, self.end)
  82. ... return iter(range(iter_start, iter_end))
  83. ...
  84. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  85. >>> ds = MyIterableDataset(start=3, end=7)
  86. >>> # Single-process loading
  87. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  88. [3, 4, 5, 6]
  89. >>> # Mult-process loading with two worker processes
  90. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  91. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  92. [3, 5, 4, 6]
  93. >>> # With even more workers
  94. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
  95. [3, 4, 5, 6]
  96. Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
  97. >>> class MyIterableDataset(torch.utils.data.IterableDataset):
  98. ... def __init__(self, start, end):
  99. ... super(MyIterableDataset).__init__()
  100. ... assert end > start, "this example code only works with end >= start"
  101. ... self.start = start
  102. ... self.end = end
  103. ...
  104. ... def __iter__(self):
  105. ... return iter(range(self.start, self.end))
  106. ...
  107. >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
  108. >>> ds = MyIterableDataset(start=3, end=7)
  109. >>> # Single-process loading
  110. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
  111. [3, 4, 5, 6]
  112. >>>
  113. >>> # Directly doing multi-process loading yields duplicate data
  114. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
  115. [3, 3, 4, 4, 5, 5, 6, 6]
  116. >>> # Define a `worker_init_fn` that configures each dataset copy differently
  117. >>> def worker_init_fn(worker_id):
  118. ... worker_info = torch.utils.data.get_worker_info()
  119. ... dataset = worker_info.dataset # the dataset copy in this worker process
  120. ... overall_start = dataset.start
  121. ... overall_end = dataset.end
  122. ... # configure the dataset to only process the split workload
  123. ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
  124. ... worker_id = worker_info.id
  125. ... dataset.start = overall_start + worker_id * per_worker
  126. ... dataset.end = min(dataset.start + per_worker, overall_end)
  127. ...
  128. >>> # Mult-process loading with the custom `worker_init_fn`
  129. >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
  130. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
  131. [3, 5, 4, 6]
  132. >>> # With even more workers
  133. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
  134. [3, 4, 5, 6]
  135. """
  136. def __iter__(self) -> Iterator[T_co]:
  137. raise NotImplementedError
  138. def __add__(self, other: Dataset[T_co]):
  139. return ChainDataset([self, other])
  140. # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
  141. # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  142. class TensorDataset(Dataset[Tuple[Tensor, ...]]):
  143. r"""Dataset wrapping tensors.
  144. Each sample will be retrieved by indexing tensors along the first dimension.
  145. Args:
  146. *tensors (Tensor): tensors that have the same size of the first dimension.
  147. """
  148. tensors: Tuple[Tensor, ...]
  149. def __init__(self, *tensors: Tensor) -> None:
  150. assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
  151. self.tensors = tensors
  152. def __getitem__(self, index):
  153. return tuple(tensor[index] for tensor in self.tensors)
  154. def __len__(self):
  155. return self.tensors[0].size(0)
  156. class ConcatDataset(Dataset[T_co]):
  157. r"""Dataset as a concatenation of multiple datasets.
  158. This class is useful to assemble different existing datasets.
  159. Args:
  160. datasets (sequence): List of datasets to be concatenated
  161. """
  162. datasets: List[Dataset[T_co]]
  163. cumulative_sizes: List[int]
  164. @staticmethod
  165. def cumsum(sequence):
  166. r, s = [], 0
  167. for e in sequence:
  168. l = len(e)
  169. r.append(l + s)
  170. s += l
  171. return r
  172. def __init__(self, datasets: Iterable[Dataset]) -> None:
  173. super(ConcatDataset, self).__init__()
  174. self.datasets = list(datasets)
  175. assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
  176. for d in self.datasets:
  177. assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
  178. self.cumulative_sizes = self.cumsum(self.datasets)
  179. def __len__(self):
  180. return self.cumulative_sizes[-1]
  181. def __getitem__(self, idx):
  182. if idx < 0:
  183. if -idx > len(self):
  184. raise ValueError("absolute value of index should not exceed dataset length")
  185. idx = len(self) + idx
  186. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  187. if dataset_idx == 0:
  188. sample_idx = idx
  189. else:
  190. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  191. return self.datasets[dataset_idx][sample_idx]
  192. @property
  193. def cummulative_sizes(self):
  194. warnings.warn("cummulative_sizes attribute is renamed to "
  195. "cumulative_sizes", DeprecationWarning, stacklevel=2)
  196. return self.cumulative_sizes
  197. class ChainDataset(IterableDataset):
  198. r"""Dataset for chaining multiple :class:`IterableDataset` s.
  199. This class is useful to assemble different existing dataset streams. The
  200. chaining operation is done on-the-fly, so concatenating large-scale
  201. datasets with this class will be efficient.
  202. Args:
  203. datasets (iterable of IterableDataset): datasets to be chained together
  204. """
  205. def __init__(self, datasets: Iterable[Dataset]) -> None:
  206. super(ChainDataset, self).__init__()
  207. self.datasets = datasets
  208. def __iter__(self):
  209. for d in self.datasets:
  210. assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
  211. for x in d:
  212. yield x
  213. def __len__(self):
  214. total = 0
  215. for d in self.datasets:
  216. assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
  217. total += len(d) # type: ignore[arg-type]
  218. return total
  219. class Subset(Dataset[T_co]):
  220. r"""
  221. Subset of a dataset at specified indices.
  222. Args:
  223. dataset (Dataset): The whole Dataset
  224. indices (sequence): Indices in the whole set selected for subset
  225. """
  226. dataset: Dataset[T_co]
  227. indices: Sequence[int]
  228. def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
  229. self.dataset = dataset
  230. self.indices = indices
  231. def __getitem__(self, idx):
  232. if isinstance(idx, list):
  233. return self.dataset[[self.indices[i] for i in idx]]
  234. return self.dataset[self.indices[idx]]
  235. def __len__(self):
  236. return len(self.indices)
  237. def random_split(dataset: Dataset[T], lengths: Sequence[int],
  238. generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
  239. r"""
  240. Randomly split a dataset into non-overlapping new datasets of given lengths.
  241. Optionally fix the generator for reproducible results, e.g.:
  242. >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
  243. Args:
  244. dataset (Dataset): Dataset to be split
  245. lengths (sequence): lengths of splits to be produced
  246. generator (Generator): Generator used for the random permutation.
  247. """
  248. # Cannot verify that dataset is Sized
  249. if sum(lengths) != len(dataset): # type: ignore[arg-type]
  250. raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
  251. indices = randperm(sum(lengths), generator=generator).tolist()
  252. return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]