在Windows系统上使用PyTorch的DataLoader加载数据时,如果设置num_workers大于0,容易遇到RuntimeError: DataLoader worker (pid(s) ...) exited unexpectedly的报错。本文将结合具体案例分析错误原因,并给出两种可行的修复方案。
环境与问题复现
操作系统:Windows 10
PyTorch版本:1.5.1+cu101
以下代码来自莫烦Python教程,用于演示mini-batch训练:
- import torch
- import torch.utils.data as Data
- print(torch.__version__)
- BATCH_SIZE = 5
- x = torch.linspace(1, 10, 10) # [1,2,3,4,5,6,7,8,9,10]
- y = torch.linspace(10, 1, 10) # [10,9,8,7,6,5,4,3,2,1]
- torch_dataset = Data.TensorDataset(x, y)
- loader = Data.DataLoader(
- dataset=torch_dataset,
- batch_size=BATCH_SIZE,
- shuffle=True,
- num_workers=2,
- )
- for epoch in range(3):
- for step, (batch_x, batch_y) in enumerate(loader):
- print('Epoch:', epoch, '|Step', step, '|batch x:', batch_x.numpy(), '|batch y:', batch_y.numpy())
复制代码
直接运行以上脚本,会抛出如下异常:
- RuntimeError:
- An attempt has been made to start a new process before the
- current process has finished its bootstrapping phase.
- This probably means that you are not using fork to start your
- child processes and you have forgotten to use the proper idiom
- in the main module:
- if __name__ == '__main__':
- freeze_support()
- ...
- ...
- RuntimeError: DataLoader worker (pid(s) 8528, 8488) exited unexpectedly
复制代码
错误定位与分析
从堆栈信息可以确定,异常发生在迭代loader的for循环语句处:- for step, (batch_x, batch_y) in enumerate(loader):
复制代码
关键线索有两个:
1. 错误提示DataLoader worker的pid(s)异常,表明num_workers=2的设置可能有问题。
2. 错误信息明确建议使用标准写法:在main模块中先调用freeze_support(),并提醒Windows下需遵守此惯用法。
结合这两点可以判断,该错误源自Windows环境下使用多进程(num_workers > 0)时,没有以正确方式启动子进程。在Linux或类Unix系统上,使用fork方式启动子进程通常可以直接运行;但Windows不支持fork,PyTorch DataLoader在Windows上使用spawn方式创建进程,因此要求主模块的代码必须被保护在if __name__ == '__main__'块内,并在其中调用freeze_support()(如果程序会被打包成exe则需要,普通脚本可省略该行)。
解决方案
方案一:将num_workers设为0,完全禁用多进程加载。
- loader = Data.DataLoader(
- dataset=torch_dataset,
- batch_size=BATCH_SIZE,
- shuffle=True,
- num_workers=0,
- )
复制代码
修改后运行成功。此方案简单但会失去并行数据加载带来的性能提升,适用于数据量小或快速调试的场景。
方案二:采用标准的多进程写法,将迭代代码放入if __name__ == '__main__'块中。
- if __name__ == '__main__':
- for epoch in range(3):
- for step, (batch_x, batch_y) in enumerate(loader):
- print('Epoch:', epoch, '|Step', step, '|batch x:', batch_x.numpy(), '|batch y:', batch_y.numpy())
复制代码
如果有freeze_support()的需求(如将脚本打包成exe),可以在main块开始处添加freeze_support(),但普通脚本中可省略。修改后num_workers=2即可正常工作。
原理说明
在Windows上,PyTorch DataLoader使用multiprocessing模块创建worker子进程,这些子进程会从头重新导入主模块。如果不加if __name__ == '__main__'保护,子进程在导入主模块时也会尝试执行创建DataLoader的代码,从而导致递归创建进程,引发错误。加上保护后,只有主进程会执行DataLoader初始化及迭代逻辑,子进程只负责加载数据。
总结
在Windows环境下使用PyTorch DataLoader且num_workers > 0时,必须将数据迭代代码包裹在if __name__ == '__main__'内。若不需要并行加载,可将num_workers设为0以快速解决问题。该规则同样适用于其他需要多进程支持的PyTorch功能(如torch.multiprocessing)。建议在编写训练脚本时始终采用标准主模块保护写法,以避免跨平台兼容性问题。 |