在PyTorch中,Dataset和DataLoader是构建数据管道的核心组件。Dataset封装数据源及访问方式,DataLoader负责批量加载、打乱、多进程并行等。本文将从自定义数据集实现、内置数据集加载、DataLoader参数详解到实战案例,帮助读者搭建高效的数据流水线。
## 自定义数据集:映射式与可迭代式
自定义Dataset必须继承torch.utils.data.Dataset并实现三个方法:
- __init__:初始化数据路径、标签、变换等;
- __len__:返回数据集样本总数;
- __getitem__(self, idx):根据索引返回样本(特征和标签)。
示例:
- from torch.utils.data import Dataset
- class MyDataset(Dataset):
- def __init__(self, data, labels):
- self.data = data
- self.labels = labels
- def __len__(self):
- return len(self.data)
- def __getitem__(self, idx):
- return self.data[idx], self.labels[idx]
复制代码
PyTorch支持两种数据集风格:
- 映射式(Map-Style):实现__getitem__和__len__,通过索引随机访问,适用于数据能全部放入内存索引结构的场景,支持shuffle。
- 可迭代式(Iterable-Style):继承IterableDataset,实现__iter__返回迭代器,适用于数据流式读取或无法统计长度的场景,不能使用len()和shuffle。
## 内置数据集加载
torchvision.datasets提供多种内置数据集,如FashionMNIST。加载时需指定root路径、train标志、transform(如ToTensor)和download是否自动下载。
- import torch
- from torchvision import datasets
- from torchvision.transforms import ToTensor
- training_data = datasets.FashionMNIST(
- root="./data",
- train=True,
- download=True,
- transform=ToTensor()
- )
- test_data = datasets.FashionMNIST(
- root="./data",
- train=False,
- download=True,
- transform=ToTensor()
- )
复制代码
访问样本:training_data[sample_idx]返回(图像张量, 标签),标签映射为0~9对应的服装类别。
## DataLoader参数详解
DataLoader将DataSet封装为迭代器,核心API及参数含义:
- DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
- batch_sampler=None, num_workers=0, collate_fn=None,
- pin_memory=False, drop_last=False, timeout=0,
- worker_init_fn=None, generator=None,
- prefetch_factor=2, persistent_workers=False)
复制代码
- dataset:需要加载的Dataset对象。
- batch_size:每个批次样本数,默认1。
- shuffle:每个epoch开始时是否打乱数据(仅映射式有效,使用RandomSampler)。
- sampler:自定义索引采样器,若指定则shuffle必须为False。
- batch_sampler:返回批次索引,与batch_size、shuffle、sampler互斥。
- num_workers:数据加载子进程数,0表示主进程,>0可加速。
- collate_fn:将样本列表合并成批次,默认按第0维堆叠;样本结构不一致时需自定义。
- pin_memory:将张量复制到CUDA固定内存,加速CPU到GPU传输(仅CUDA)。
- drop_last:丢弃最后一个不完整批次,BatchNorm等需固定批次大小时建议True。
- timeout:获取批次超时时间(秒)。
- worker_init_fn:worker初始化函数,常用于设置随机种子。
- generator:伪随机数生成器,保证可复现。
- prefetch_factor:每个worker预加载的batch数(默认2)。
- persistent_workers:是否保持worker进程存活以加速后续epoch。
## 实战案例:鸢尾花数据集
使用sklearn的iris数据集,实现数据归一化、随机打乱,并划分训练/验证/测试集,通过DataLoader批量加载。
- import torch
- from sklearn.datasets import load_iris
- from torch.utils.data import Dataset, DataLoader
- def load_data(shuffle=True):
- x = torch.tensor(load_iris().data)
- y = torch.tensor(load_iris().target)
- # 数据归一化
- x_min = torch.min(x, dim=0).values
- x_max = torch.max(x, dim=0).values
- x = (x - x_min) / (x_max - x_min)
- if shuffle:
- idx = torch.randperm(x.shape[0])
- x = x[idx]
- y = y[idx]
- return x, y
- class IrisDataset(Dataset):
- def __init__(self, mode='train', num_train=120, num_dev=15):
- super().__init__()
- x, y = load_data(shuffle=True)
- if mode == 'train':
- self.x, self.y = x[:num_train], y[:num_train]
- elif mode == 'dev':
- self.x, self.y = x[num_train:num_train+num_dev], y[num_train:num_train+num_dev]
- else:
- self.x, self.y = x[num_train+num_dev:], y[num_train+num_dev:]
- def __getitem__(self, idx):
- return self.x[idx], self.y[idx]
- def __len__(self):
- return len(self.x)
- batch_size = 16
- train_dataset = IrisDataset(mode='train')
- dev_dataset = IrisDataset(mode='dev')
- test_dataset = IrisDataset(mode='test')
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
- dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
- test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
- # 遍历训练集
- for batch_x, batch_y in train_loader:
- # 训练模型
- pass
复制代码
## 总结
- Dataset定义数据源与访问方式,映射式最常用,流式数据用IterableDataset。
- DataLoader封装采样、批处理、多进程加载、内存固定等,参数丰富,理解每个参数可应对不同场景。
- 通过自定义sampler和collate_fn可灵活处理非平衡数据或异构样本。
- 多进程加载(num_workers>0)是加速训练的关键,需关注内存复制和系统兼容性。
掌握Dataset和DataLoader的内部机制,能构建高效数据管道,降低I/O瓶颈,充分释放GPU算力。 |