查看: 89|回复: 1

PyTorch Dataset与DataLoader实战:自定义数据集与参数配置详解

[复制链接]
发表于 2 小时前 | 显示全部楼层 |阅读模式
在PyTorch中,Dataset和DataLoader是构建数据管道的核心组件。Dataset封装数据源及访问方式,DataLoader负责批量加载、打乱、多进程并行等。本文将从自定义数据集实现、内置数据集加载、DataLoader参数详解到实战案例,帮助读者搭建高效的数据流水线。

## 自定义数据集:映射式与可迭代式

自定义Dataset必须继承torch.utils.data.Dataset并实现三个方法:

- __init__:初始化数据路径、标签、变换等;
- __len__:返回数据集样本总数;
- __getitem__(self, idx):根据索引返回样本(特征和标签)。

示例:
  1. from torch.utils.data import Dataset
  2. class MyDataset(Dataset):
  3.     def __init__(self, data, labels):
  4.         self.data = data
  5.         self.labels = labels
  6.     def __len__(self):
  7.         return len(self.data)
  8.     def __getitem__(self, idx):
  9.         return self.data[idx], self.labels[idx]
复制代码

PyTorch支持两种数据集风格:

- 映射式(Map-Style):实现__getitem__和__len__,通过索引随机访问,适用于数据能全部放入内存索引结构的场景,支持shuffle。
- 可迭代式(Iterable-Style):继承IterableDataset,实现__iter__返回迭代器,适用于数据流式读取或无法统计长度的场景,不能使用len()和shuffle。

## 内置数据集加载

torchvision.datasets提供多种内置数据集,如FashionMNIST。加载时需指定root路径、train标志、transform(如ToTensor)和download是否自动下载。
  1. import torch
  2. from torchvision import datasets
  3. from torchvision.transforms import ToTensor
  4. training_data = datasets.FashionMNIST(
  5.     root="./data",
  6.     train=True,
  7.     download=True,
  8.     transform=ToTensor()
  9. )
  10. test_data = datasets.FashionMNIST(
  11.     root="./data",
  12.     train=False,
  13.     download=True,
  14.     transform=ToTensor()
  15. )
复制代码

访问样本:training_data[sample_idx]返回(图像张量, 标签),标签映射为0~9对应的服装类别。

## DataLoader参数详解

DataLoader将DataSet封装为迭代器,核心API及参数含义:
  1. DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
  2.            batch_sampler=None, num_workers=0, collate_fn=None,
  3.            pin_memory=False, drop_last=False, timeout=0,
  4.            worker_init_fn=None, generator=None,
  5.            prefetch_factor=2, persistent_workers=False)
复制代码

- dataset:需要加载的Dataset对象。
- batch_size:每个批次样本数,默认1。
- shuffle:每个epoch开始时是否打乱数据(仅映射式有效,使用RandomSampler)。
- sampler:自定义索引采样器,若指定则shuffle必须为False。
- batch_sampler:返回批次索引,与batch_size、shuffle、sampler互斥。
- num_workers:数据加载子进程数,0表示主进程,>0可加速。
- collate_fn:将样本列表合并成批次,默认按第0维堆叠;样本结构不一致时需自定义。
- pin_memory:将张量复制到CUDA固定内存,加速CPU到GPU传输(仅CUDA)。
- drop_last:丢弃最后一个不完整批次,BatchNorm等需固定批次大小时建议True。
- timeout:获取批次超时时间(秒)。
- worker_init_fn:worker初始化函数,常用于设置随机种子。
- generator:伪随机数生成器,保证可复现。
- prefetch_factor:每个worker预加载的batch数(默认2)。
- persistent_workers:是否保持worker进程存活以加速后续epoch。

## 实战案例:鸢尾花数据集

使用sklearn的iris数据集,实现数据归一化、随机打乱,并划分训练/验证/测试集,通过DataLoader批量加载。
  1. import torch
  2. from sklearn.datasets import load_iris
  3. from torch.utils.data import Dataset, DataLoader
  4. def load_data(shuffle=True):
  5.     x = torch.tensor(load_iris().data)
  6.     y = torch.tensor(load_iris().target)
  7.     # 数据归一化
  8.     x_min = torch.min(x, dim=0).values
  9.     x_max = torch.max(x, dim=0).values
  10.     x = (x - x_min) / (x_max - x_min)
  11.     if shuffle:
  12.         idx = torch.randperm(x.shape[0])
  13.         x = x[idx]
  14.         y = y[idx]
  15.     return x, y
  16. class IrisDataset(Dataset):
  17.     def __init__(self, mode='train', num_train=120, num_dev=15):
  18.         super().__init__()
  19.         x, y = load_data(shuffle=True)
  20.         if mode == 'train':
  21.             self.x, self.y = x[:num_train], y[:num_train]
  22.         elif mode == 'dev':
  23.             self.x, self.y = x[num_train:num_train+num_dev], y[num_train:num_train+num_dev]
  24.         else:
  25.             self.x, self.y = x[num_train+num_dev:], y[num_train+num_dev:]
  26.     def __getitem__(self, idx):
  27.         return self.x[idx], self.y[idx]
  28.     def __len__(self):
  29.         return len(self.x)
  30. batch_size = 16
  31. train_dataset = IrisDataset(mode='train')
  32. dev_dataset = IrisDataset(mode='dev')
  33. test_dataset = IrisDataset(mode='test')
  34. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  35. dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
  36. test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
  37. # 遍历训练集
  38. for batch_x, batch_y in train_loader:
  39.     # 训练模型
  40.     pass
复制代码

## 总结

- Dataset定义数据源与访问方式,映射式最常用,流式数据用IterableDataset。
- DataLoader封装采样、批处理、多进程加载、内存固定等,参数丰富,理解每个参数可应对不同场景。
- 通过自定义sampler和collate_fn可灵活处理非平衡数据或异构样本。
- 多进程加载(num_workers>0)是加速训练的关键,需关注内存复制和系统兼容性。

掌握Dataset和DataLoader的内部机制,能构建高效数据管道,降低I/O瓶颈,充分释放GPU算力。
回复

使用道具 举报

发表于 1 小时前 | 显示全部楼层

Re: PyTorch Dataset与DataLoader实战:自定义数据集与参数配置详解

干货满满的帖!自定义数据集的映射式和可迭代式区别讲得很清楚,DataLoader参数列表也很有参考价值。 不过实战案例里有个小笔误:`load_data` 函数里生成 `idx` 后,`x = x; y = y` 并没有真正打乱,应该改成 `x = x; y = y`?另外最后一行 `if mode == 'trai` 似乎被截断了,方便的话补全一下?期待后续完善的内容~
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 注册

本版积分规则

指导单位

江苏省公安厅

江苏省通信管理局

浙江省台州刑侦支队

DEFCON GROUP 86025

Hacking Group 021A

旗下站点

态势感知中心

应急响应中心

红盟安全

联系我们

官方QQ群:112851260

官方邮箱:security#ihonker.org(#改成@)

官方核心成员

关注微信公众号

Archiver|手机版|小黑屋| ( 沪ICP备2021026908号 )

GMT+8, 2026-6-12 13:06 , Processed in 0.028574 second(s), 17 queries , Gzip On, Redis On.

Powered by ihonker.com

Copyright © 2015-现在.

  • 返回顶部