torch.utils.data
PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader
类。它表示一个可以迭代的 dataset,并提供了相应的支持功能。
-
批量加载,
-
自动内存 pinned. 注:此处根据上下文更倾向于“自动内存固定”,但由于直接翻译可能有细微差异,保持原回答以"a href="#memory-pinning">自动内存固定"为准。若需强调技术准确性,“自动内存固定”更为贴切。由于要求自然通顺和易于理解,维持原文表述较为适宜。因此,返回原文: 自动内存固定.
这些选项通过 DataLoader
的构造函数参数进行配置,其签名如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
以下章节详细介绍了这些选项的效果和用法。
数据集类型
DataLoader
构造函数最重要的参数是 dataset
,它表示一个用于加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:
映射式数据集
一种图谱式数据集实现了__getitem__()
和__len__()
协议,并表示从(可能是非整数的)索引/键到数据样本的映射。
例如,当通过 dataset[idx]
访问这样的数据集时,可以从磁盘上的文件夹中读取第 idx
个图像及其对应的标签。
更多详情请参见Dataset
。
迭代式数据集
迭代式数据集是IterableDataset
子类的一个实例,它实现了__iter__()
协议,并表示一个可迭代的数据样本集合。这种类型的数据集特别适合于随机读取成本高昂或不可能的情况,以及批次大小取决于获取数据的场景。
例如,当调用iter(dataset)
时,这样的数据集可以返回一个数据流,这个数据流可以从数据库、远程服务器或甚至是实时生成的日志中读取数据。
更多详情请参见IterableDataset
。
注意
当使用IterableDataset
与多进程数据加载时,同一个数据集对象会在每个工作进程中被复制。为了防止数据重复,这些副本需要进行不同的配置。请参阅IterableDataset
文档以了解如何实现这一点。
数据加载顺序及 Sampler
对于可迭代风格的数据集,数据加载顺序完全由用户定义的可迭代对象控制。这使得实现块读取和动态批量大小(例如通过在每个时间点生成批处理样本)更加容易。
本节的其余部分将讨论映射式数据集的情况。 torch.utils.data.Sampler
类用于指定在数据加载过程中使用的索引序列。它们表示可迭代的对象,遍历数据集中每个索引的集合。例如,在常见的随机梯度下降(SGD)情况下,一个Sampler
可以随机排列一组索引并逐个返回,或者一次返回一小部分用于小批量 SGD。
根据传递给DataLoader
的shuffle
参数,将自动构建一个顺序或随机排列的采样器。或者,用户可以使用sampler
参数指定一个自定义的Sampler
对象,在每次调用时生成下一个要获取的数据索引或键。
可以通过传递一个自定义的 Sampler
来生成一批索引列表,并通过设置 batch_sampler
参数来实现。还可以通过设置 batch_size
和 drop_last
参数启用自动批量处理功能。有关此功能的更多详细信息,请参见下一节。
注意
无论是sampler
还是 batch_sampler
都不适用于迭代式数据集,因为这类数据集没有键或索引的概念。
加载批量和非批量数据
DataLoader
通过参数 batch_size
、drop_last
、batch_sampler
和 collate_fn
(该参数有一个默认函数)支持自动将获取的单个数据样本汇集到批次中。
自动批处理(默认设置)
这是最常见的场景,对应于获取一个小批量数据,并将这些数据整理成批处理样本,每个样本包含一个维度作为批量维度(通常是最前面的那个维度)的张量。
当 batch_size
(默认为 1
) 不是 None
时,数据加载器会生成批量样本而不是单个样本。通过设置 batch_size
和 drop_last
参数,可以指定数据加载器如何获取数据集键的批次。对于 map 风格的数据集,用户还可以选择使用 batch_sampler
来一次生成一个键列表。
注意
参数 batch_size
和 drop_last
用于从 sampler
构造一个 batch_sampler
。对于 map 式数据集,sampler
可以由用户提供或基于 shuffle
参数构建。而对于 iterable 式数据集,sampler
则是一个假定的无限循环采样器。有关更多详细信息,请参见 本节。
注意
当从支持可迭代风格的数据集并使用多进程数据加载时,drop_last
参数会丢弃每个工作进程中数据集副本的最后一个非完整批次。
在使用sampler提供的索引获取样本列表之后,会用作collate_fn
参数的函数来将这些样本列表整理成批次。
在这种情况下,从地图式的数据集进行加载与以下操作大致相同:
for indices in batch_sampler: yield collate_fn([dataset[i] for i in indices])
从迭代式数据集加载数据大致等同于:
dataset_iter = iter(dataset) for indices in batch_sampler: yield collate_fn([next(dataset_iter) for _ in indices])
可以使用自定义的 collate_fn
来定制数据整理,例如将序列数据填充到批次中的最大长度。有关更多关于 collate_fn
的信息,请参见本节。
关闭自动批处理
在某些情况下,用户可能希望手动处理数据集中的批处理,或者简单地加载单个样本。例如,直接从数据库批量读取或连续读取内存块的成本更低,或者批量大小取决于数据,或者程序设计为针对单一样本工作。在这种情况下,不使用自动批处理(其中collate_fn
用于整理样本)可能更好,而是让数据加载器直接返回dataset
对象的每个成员。
当batch_size
和batch_sampler
都为None
(batch_sampler
的默认值已经是None
)时,自动批量处理被禁用。从dataset
获取的每个样本将使用传递给collate_fn
参数的函数进行处理。
当自动批处理被禁用时,默认的 collate_fn
只会将 NumPy 数组转换为 PyTorch 张量,并保持其他数据不变。
在这种情况下,从地图式的数据集进行加载与以下操作大致相同:
for index in sampler: yield collate_fn(dataset[index])
从迭代式数据集加载数据大致等同于:
for data in iter(dataset): yield collate_fn(data)
关于collate_fn
的更多内容,请参见本节。
与 collate_fn
一起工作
在自动批量处理启用或禁用的情况下,collate_fn
的使用方式会有所不同。
当自动批量处理被禁用时,collate_fn
会针对每个单独的数据样本进行调用,并且数据加载器迭代器会输出结果。在这种情况下,默认的 collate_fn
简单地将 NumPy 数组转换为 PyTorch 张量。
当自动批处理启用时,collate_fn
在每个时间点会接收一个数据样本列表。它负责将这些输入样本整理成一批数据,并返回给数据加载器迭代器。本节的其余部分描述了默认 collate_fn
(default_collate()
)的行为。
例如,如果每个数据样本包含一个三通道图像和一个整数类别标签(即数据集中的每个元素返回一个元组 (image, class_index)
),默认的 collate_fn
会将这些元组列表整理成一个包含批次图像张量和批次类别标签张量的元组。具体来说, 默认的 collate_fn
具有以下属性:
-
它总是将一个新的维度添加作为批次维度。
-
它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。
-
它保留数据结构,例如,如果每个样本是一个字典,则输出一个具有相同键集的字典,但值为批量张量(如果值不能转换为张量,则使用列表)。对于
list
、tuple
和namedtuple
等也是如此。
用户可以使用自定义的 collate_fn
函数来实现各种自定义批处理操作,比如在非第一维度上进行批处理、填充不同长度的序列,或者支持自定义的数据类型。
如果发现 DataLoader
的输出维度或类型与预期不符,你需要检查一下 collate_fn
。
单进程与多进程数据加载
默认情况下,一个 DataLoader
使用单进程数据加载。
在一个 Python 进程中,全局解释器锁(GIL)阻止了 Python 代码在多线程之间真正地完全并行化。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的设置方法:将参数 num_workers
设置为一个正整数,从而实现多进程数据加载。
单进程数据加载(默认设置)
在此模式下,数据获取在初始化DataLoader
的同一进程中完成,因此可能会阻塞计算。然而,在资源(如共享内存、文件描述符)有限或整个数据集较小且可以完全加载到内存中时,此模式可能是首选。此外,单进程加载通常提供更易读的错误跟踪信息,对于调试非常有用。
多进程数据加载
将参数 num_workers
设置为正整数,可以启用多进程数据加载,并指定加载器的工作进程数量。
警告
经过几次迭代后,加载器工作进程将消耗与父进程中所有被工作进程访问的 Python 对象相同的 CPU 内存。如果数据集包含大量数据(例如,在构建数据集时加载非常大的文件名列表)和/或使用了大量工作进程(总体内存使用量为 number of workers * size of parent process
),这可能会成为一个问题。最简单的解决方法是将 Python 对象替换为非引用计数表示形式,例如 Pandas、Numpy 或 PyArrow 对象。有关此问题发生的原因以及如何绕过这些问题的示例代码,请参阅 issue #13246。
在此模式下,每次创建 DataLoader
的迭代器(例如当你调用 enumerate(dataloader)
时),会启动 num_workers
个工作进程。此时,dataset
、collate_fn
和 worker_init_fn
被传递给每个工作进程,在那里它们被用来初始化并获取数据。这意味着数据集的访问及其内部 IO 和转换(包括 collate_fn
)都在这些工作进程中运行。
torch.utils.data.get_worker_info()
在工作进程中返回各种有用的信息(包括工作进程ID、数据集副本和初始种子等),并在主进程中返回 None
。用户可以在数据集中使用此函数或 worker_init_fn
来单独配置每个数据集副本,并判断代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有帮助。
对于 map-style 数据集,主进程使用 sampler
生成索引并将其发送到工作进程。因此,洗牌随机化在主进程中完成,并通过分配要加载的索引来指导数据加载。
对于可迭代风格的数据集,由于每个工作进程都会获得一个dataset
对象的副本,简单的多进程加载通常会导致数据重复。通过使用torch.utils.data.get_worker_info()
和/或worker_init_fn
,用户可以独立配置每个副本。(参见IterableDataset
文档以了解如何实现这一点。)出于类似原因,在多进程加载中,drop_last
参数会丢弃每个工作进程的可迭代风格数据集副本中的最后一个非完整批次。
当迭代结束或迭代器被垃圾回收时,工作者将被关闭。
警告
由于在多进程环境中使用 CUDA 和共享 CUDA 张量存在许多复杂性(参见CUDA 在多进程中),通常不建议返回 CUDA 张量。相反,我们推荐使用 自动内存固定 (即设置pin_memory=True
),以实现数据快速传输到支持 CUDA 的 GPU。
平台特定行为
因为工作者依赖于 Python 的 multiprocessing
模块,所以在 Windows 上的启动行为与 Unix 系统有所不同。
-
在 Unix 系统上,
fork()
是默认的multiprocessing
启动方法。使用fork()
,子进程可以直接通过克隆的地址空间访问dataset
和 Python 参数函数。 -
在 Windows 或 MacOS 上,
spawn()
是默认的multiprocessing
启动方法。使用spawn()
,会启动另一个解释器来运行你的主脚本,并通过pickle
序列化传递内部工作函数。这个工作函数接收dataset
、collate_fn
和其他参数。
这种独立的序列化意味着,在使用多进程数据加载时,你需要采取两步措施来确保与 Windows 的兼容性:
-
将你主脚本的大部分代码包裹在
if __name__ == '__main__':
块中,以确保它不会再次运行(很可能生成错误),当每个工作进程启动时。你可以在这里放置数据集和DataLoader
实例的创建逻辑,因为这些不需要在工作进程中重新执行。 -
请将任何自定义的
)collate_fn
,worker_init_fn
或dataset
代码声明为顶级定义,置于__main__
检查之外。这样可以确保它们在工作进程中可用。(因为函数仅通过引用进行序列化,而不是通过字节码。
多进程数据加载中的随机性
默认情况下,每个工作进程将设置其 PyTorch 种子为 base_seed + worker_id
,其中 base_seed
是主进程中使用随机数生成器(强制消耗一个 RNG 状态)或指定的 generator
生成的一个长整型值。然而,在初始化工作进程时,其他库可能会重复设置种子,导致每个工作进程返回相同的随机数。(参见 FAQ 中的此部分)。
在worker_init_fn
中,你可以通过torch.utils.data.get_worker_info().seed
或torch.initial_seed()
访问为每个工作者设置的PyTorch种子,并在数据加载之前使用该种子来初始化其他库。
内存 pinning
当主机到 GPU 的数据复制操作起源于固定(页锁定)内存时,速度会快得多。有关何时以及如何一般性地使用固定内存的详细信息,请参见使用固定内存缓冲区。
在数据加载时,将pin_memory=True
传递给DataLoader
会自动将获取的数据张量放置在固定内存中,并且可以更快地传输到支持 CUDA 的 GPU。
默认的内存固定逻辑仅识别张量及其包含映射和可迭代对象。默认情况下,如果固定逻辑遇到由自定义类型组成的批次(这种情况通常发生在你使用返回自定义批次类型的collate_fn
时),或者如果你的每个批次元素都是自定义类型,固定逻辑将不会识别它们,并且会直接返回该批次(或那些元素)而不进行内存固定。为了启用自定义批次或数据类型的内存固定,请在你的自定义类型上实现一个pin_memory()
方法。
请参看以下示例。
示例:
class SimpleCustomBatch: def __init__(self, data): transposed_data = list(zip(*data)) self.inp = torch.stack(transposed_data[0], 0) self.tgt = torch.stack(transposed_data[1], 0) # custom memory pinning method on custom type def pin_memory(self): self.inp = self.inp.pin_memory() self.tgt = self.tgt.pin_memory() return self def collate_wrapper(batch): return SimpleCustomBatch(batch) inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) dataset = TensorDataset(inps, tgts) loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True) for batch_ndx, sample in enumerate(loader): print(sample.inp.is_pinned()) print(sample.tgt.is_pinned())
- 类torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[源代码]
-
数据加载器结合数据集和采样器,提供一个可以遍历给定数据集的对象。
DataLoader
同时支持 map 式和可迭代式数据集,提供单进程或多进程加载功能,并允许自定义加载顺序以及可选的自动批量处理(整理)和内存固定。参见
torch.utils.data
文档页面以获取更多信息。- 参数
-
-
dataset (Dataset) — 数据集,从中加载数据。
-
batch_size (int, 可选) – 每批加载的样本数量(默认为
1
)。 -
shuffle (bool, 可选) – 将其设置为
True
,以便在每个 epoch 中重新洗牌数据(默认值:False
)。 -
sampler (Sampler 或 Iterable, 可选) – 定义从数据集抽取样本的策略。可以是任何实现了
__len__
的Iterable
对象。如果指定了 sampler,则不能指定shuffle
。 -
batch_sampler (Sampler 或 Iterable, 可选) – 类似于
sampler
,但每次返回一批索引。与batch_size
、shuffle
、sampler
和drop_last
互斥。 -
num_workers (int, 可选) – 用于数据加载的子进程数量。如果设置为
0
,则表示数据将在主进程中加载。(默认值:0
) -
collate_fn (Callable, 可选) – 合并样本列表以形成小批次的张量。在从映射式数据集进行批量加载时使用。
-
pin_memory (bool, 可选) – 如果设置为
True
,数据加载器会在返回张量之前将其复制到设备/CUDA固定内存中。如果你的数据元素是自定义类型,或者你的collate_fn
返回的批次是自定义类型,请参见下面的例子。 -
drop_last (bool, optional) – 如果设置为
True
,当数据集大小不能被批量大小整除时,会丢弃最后一个不完整的批次。如果设置为False
且数据集大小不能被批量大小整除,则最后一个批次的大小会小于其他批次。(默认值:False
) -
timeout (数值型, 可选) – 如果为正数,则表示从工作者收集一批数据的超时时间。该值应始终是非负数。(默认值:
0
) -
worker_init_fn (Callable, optional) – 如果不为
None
,这将在每个工作进程子进程中被调用,并以 worker id(范围在[0, num_workers - 1]
的整数)作为输入,在播种之后和数据加载之前。 (默认值:None
) -
multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) – 如果为
None
,将使用操作系统的默认多进程上下文。 (默认:None
) -
generator (torch.Generator, optional) – 如果不为
None
,此随机数生成器将被 RandomSampler 用于生成随机索引,并且会被多进程使用以生成工人的base_seed
。(默认值:None
) -
prefetch_factor (int, 可选, 关键字参数) – 每个工作者提前加载的批次数量。例如,
2
表示所有工作者将总共预取 2 * num_workers 批次。(默认值取决于设置的 num_workers 值:如果 num_workers=0,默认为None
;否则,默认为2
)。 -
persistent_workers (bool, 可选) – 如果为
True
,数据加载器在一次消费完数据集后不会关闭工作进程。这样可以保持 Dataset 实例的存活状态。(默认值:False
) -
pin_memory_device (str, optional) – 当
pin_memory
为True
时,指定要将pin_memory
设置到的设备。
-
警告
如果使用了
spawn
启动方法,worker_init_fn
不能是一个无法序列化的对象(例如lambda函数)。有关PyTorch中多进程的更多最佳实践,请参见Multiprocessing best practices。警告
len(dataloader)
的启发式方法基于所使用的采样器的长度。当dataset
是一个IterableDataset
时,它会根据len(dataset) / batch_size
返回一个估计值,并且根据drop_last
进行适当的舍入处理。无论多进程加载配置如何,这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户提供的dataset
代码能够正确处理多进程加载以避免重复数据。然而,如果分片导致多个工作进程拥有不完整的最后一个批次,此估计仍然可能不准确,因为(1)原本完整的批次会被分成多个批次,以及(2)当
drop_last
被设置时,可能会丢弃超过一个批次的样本。不幸的是,PyTorch 通常无法检测到这种情况。有关这两种类型的数据集以及
IterableDataset
如何与多进程数据加载交互的更多详细信息,请参阅数据集类型。警告
参见重现性,以及数据加载器工作进程返回相同的随机数和多进程数据加载中的随机性的相关说明。
- 类torch.utils.data.Dataset[源代码]
-
代表一个
Dataset
的抽象类。所有表示从键到数据样本映射的数据集都应该继承它。所有子类应重写
__getitem__()
方法,以支持根据给定的键获取数据样本。子类还可以选择性地重写__len__()
方法,该方法返回数据集的大小,许多Sampler
实现和DataLoader
默认选项都依赖于此。子类还可以选择性地实现__getitems__()
方法,以加快批量样本的加载速度。此方法接受一批样本索引列表并返回相应的样本列表。注意
DataLoader
默认构造一个索引采样器,生成整数索引。为了使其与具有非整数索引/键的 map 式数据集一起工作,需要提供自定义采样器。
- 类torch.utils.data.IterableDataset[源代码]
-
一个可迭代的数据集。
所有表示数据样本序列的数据集都应该继承该类。当数据来自流式传输时,这种类型的数据集特别有用。
所有子类都应该重写
__iter__()
方法,使其返回该数据集中样本的迭代器。当子类与
DataLoader
一起使用时,数据集中的每一项将由DataLoader
迭代器生成。如果num_workers > 0
,每个工作进程都会有自己的数据集对象副本,因此通常需要独立配置这些副本以避免返回重复的数据。get_worker_info()
在工作进程中调用时会提供有关该工作的信息。此函数可以在数据集的__iter__()
方法或DataLoader
的worker_init_fn
选项中使用,以修改每个副本的行为。示例 1:在
__iter__()
中将工作负载分发给所有工作者:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
示例 2:使用
worker_init_fn
将工作负载分配给所有工作者:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
- 类torch.utils.data.TensorDataset(*tensors)[源代码]
-
将张量封装为数据集。
每个样本将通过沿着第一维度索引张量来进行检索。
- 参数
-
*tensors (Tensor) – 第一维度大小相同的张量。
- 类torch.utils.data.StackDataset(*args, **kwargs)[源代码]
-
数据集由多个数据集堆叠而成。
此类有助于将复杂的输入数据的不同部分组合在一起,这些数据以数据集的形式提供。
示例
>>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
- 类torch.utils.data.ConcatDataset(datasets)[源代码]
-
数据集是由多个数据集组合而成的。
此类用于组合不同的现有数据集。- 参数
-
datasets (序列) – 需要拼接的 dataset 列表
- 类torch.utils.data.ChainDataset(datasets)[源代码]
-
用于链式连接多个
IterableDataset
的数据集。此类有助于组合不同的现有数据流。连接操作会实时进行,因此使用此类来连接大规模数据集将会非常高效。
- 参数
-
datasets ( IterableDataset 的可迭代对象) – 需要链接在一起的数据集
- 类torch.utils.data.Subset(dataset, indices)[源代码]
-
数据集在指定索引的子集。
- 参数
-
-
dataset (Dataset) – 数据集的全部内容
-
indices (序列) – 子集在完整集合中所选的索引
-
- torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[源代码]
-
一个用于处理每个批次中集合类型元素的通用 collate 函数。
此函数还允许对特定元素类型进行处理,并通过default_collate_fn_map 提供了针对张量、numpy数组、数字和字符串的默认collate函数。
- 参数
示例
>>> def collate_tensor_fn(batch, *, collate_fn_map): ... # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
注意
每个 collate 函数都需要一个位置参数来指定批次,并且需要一个关键字参数来指定 collate 函数的字典,即 collate_fn_map。
- torch.utils.data.default_collate(batch)[源代码]
-
接收一批数据,并将其元素放入一个具有额外批次大小维度的张量中。
确切的输出类型可以是一个
torch.Tensor
,一个包含多个torch.Tensor
的序列(Sequence),或者一个包含torch.Tensor
的集合。具体取决于输入类型,输出可以保持不变。当在DataLoader
中定义了batch_size或batch_sampler时,这被用作默认的批处理函数。这里是一般输入类型到输出类型的映射,映射基于批次中元素的类型:
-
torch.Tensor
->torch.Tensor
(增加了一个外层的批次大小维度) -
NumPy 数组 ->
torch.Tensor
-
float ->
torch.Tensor
-
int ->
torch.Tensor
-
str -> str (保持不变)
-
bytes -> bytes(保持不变)
-
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, ...])]
-
NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
-
Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
- 参数
-
batch – 单个待合并的批次
示例
>>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(['a', 'b', 'c']) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple('Point', ['x', 'y']) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically
-
- torch.utils.data.default_convert(data)[源代码]
-
将每个 NumPy 数组元素转化为
torch.Tensor
。如果输入是一个Sequence、Collection或Mapping,它会尝试将每个元素转换为
torch.Tensor
。如果输入不是NumPy数组,则保持不变。当在DataLoader
中既没有定义batch_sampler也没有定义batch_size时,这被用作默认的批处理函数。输入类型到输出类型的映射一般与
default_collate()
类似。更多细节请参阅相关描述。- 参数
-
data — 一个待转换的数据点
示例
>>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])]
- torch.utils.data.get_worker_info()[源代码]
-
返回当前
DataLoader
迭代器 worker 进程的信息。在这个工作进程中调用时,它将返回一个包含以下属性的对象:
-
id
: 当前工人的ID。 -
num_workers
: 总工作人数。 -
seed
: 当前工作者进程使用的随机种子值,该值由主进程的随机数生成器和工作者ID共同确定。更多详情请参阅DataLoader
的文档。 -
dataset
: 此进程中的数据集对象的副本。需要注意的是,其他进程中的该对象会有所不同。
在主进程中调用时,它将返回
None
。注意
当在传递给
DataLoader
的worker_init_fn
中使用时,此方法可以用于以不同的方式初始化每个工作进程。例如,可以使用worker_id
来配置dataset
对象仅读取分片数据集中的特定部分,或者使用seed
为在数据集中使用的其他库进行随机种子设置。- 返回类型
-
- torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[源代码]
-
将数据集随机分成若干个指定长度的不重叠新数据集。
如果给出了一组总和为1的分数,系统会自动计算每个分数对应的长度,公式为floor(frac * len(dataset))。
计算长度后,如果有余数,就会依次轮流分配这些余数到各个长度上,直到所有余数都被分配完毕。
可选地固定生成器,以便获得可重复的结果,例如:
示例
>>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- 类torch.utils.data.Sampler(data_source=None)[源代码]
-
所有_sampler_的基类。
每个 Sampler 子类都必须实现一个
__iter__()
方法,用于遍历数据集元素的索引或索引列表(批次),并且可以提供一个__len__()
方法来返回迭代器的长度。- 参数
-
data_source (Dataset) – 此参数未被使用,将在 2.2.0 版本中移除。你可能仍然有自定义实现依赖于它。
示例
>>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist()
注意
__len__()
方法不是严格要求在DataLoader
中实现的,但在涉及DataLoader
长度的任何计算中是期望被实现的。
- 类torch.utils.data.SequentialSampler(data_source)[源代码]
-
按照固定的顺序依次抽取元素。
- 参数
-
data_source (Dataset) – 需要从中采样的数据集
- 类torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[源代码]
-
随机抽取元素。如果没有放回,则从洗乱顺序的数据集中抽取。
如果有放回的情况,用户可以指定
num_samples
来抽取样本。
- 类torch.utils.data.SubsetRandomSampler(indices, generator=None)[源代码]
-
从给定的索引列表中随机抽取元素,且不重复。
- 参数
-
-
indices (序列) – 索引的序列
-
generator (Generator) – 用于样本生成的生成器。
-
- 类torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[源代码]
-
根据给定的概率(权重),从
[0, ..., len(weights) - 1]
中抽取元素。- 参数
示例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) [4, 4, 1, 4, 5] >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) [0, 1, 4, 3, 2]
- 类torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[源代码]
-
封装另一个采样器以生成一批索引。
- 参数
示例
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- 类 torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False) [源代码]]
-
一个限制数据加载为数据集子集的采样器。
它尤其与
torch.nn.parallel.DistributedDataParallel
结合使用时非常有用。在这种情况下,每个进程可以将一个DistributedSampler
实例作为DataLoader
的采样器,并加载专属于该进程的原始数据集的一个子集。注意
假定数据集的大小是固定的,而且它的任何实例总是以相同的顺序返回相同的元素。
- 参数
-
-
dataset (Dataset) — 用于采样的数据集。
-
num_replicas (int, 可选) – 参与分布式训练的进程数量。默认情况下,
world_size
从当前分布式组中获取。 -
rank (int, 可选) – 当前进程在
num_replicas
中的排名。默认情况下,rank
会从当前的分布式组中获取。 -
shuffle (bool, 可选) – 如果为
True
(默认值),采样器将对索引进行随机排列。 -
seed (int, optional) – 在
shuffle=True
时用于打乱采样的随机种子。此数字在分布式组中的所有进程中应保持一致。默认值:0
。 -
drop_last (bool, optional) – 如果为
True
,采样器将丢弃数据尾部以使其在副本数量之间均匀划分。如果为False
,采样器将在副本之间添加额外的索引以使数据均匀划分。默认值:False
。
-
警告
在分布式模式下,每个 epoch 开始时,在创建
DataLoader
迭代器之前调用set_epoch()
方法是必要的,以确保 shuffle 跨多个 epoch 正确工作。否则,每次都会使用相同的顺序。示例:
>>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader)