在PyTorch中,unfold是一个用于滑动窗口数据提取的操作,它不执行卷积或池化等计算,只负责将张量沿指定维度切分成多个局部区域。理解unfold的原理对于实现Vision Transformer中的patch embedding、滑动窗口特征提取等任务至关重要。
一、unfold基本格式与参数
调用方式:x.unfold(dimension, size, step)
- dimension:切分的维度索引
- size:窗口大小
- step:每次滑动的步长
返回的新张量中,原维度位置变为窗口数量维度,末尾新增一个维度表示窗口内元素个数。
二、一维张量例子- import torch
- x = torch.tensor([1, 2, 3, 4, 5, 6])
- y = x.unfold(0, 3, 1)
- print(y)
- # tensor([[1, 2, 3],
- # [2, 3, 4],
- # [3, 4, 5],
- # [4, 5, 6]])
- print(y.shape) # [4, 3]
复制代码 原长度6被拆成4个窗口,每个窗口长度3。
三、size与step的关系
1. step < size:窗口重叠。例如unfold(0,3,1)得到4个窗口,相邻窗口有2个元素重叠,常用于滑动窗口特征提取。
2. step == size:窗口不重叠。例如unfold(0,3,3)得到[[1,2,3],[4,5,6]],图像切patch时通常采用此方式。
3. step > size:窗口间有空隙,部分数据被跳过。例如unfold(0,2,3)得到[[1,2],[4,5]],中间的元素3未被使用。
四、窗口数量计算公式
假设原始长度为L,窗口大小size,步长step,则窗口数量为 floor((L - size) / step) + 1。注意unfold不会自动补边,剩余数据不足一个完整窗口时会被舍弃。例如L=6,size=3,step=2,窗口数为floor((6-3)/2)+1=2。
五、unfold为什么多出一维
unfold将一个维度拆解为两个维度:窗口数量维度和窗口内部大小维度。窗口数量维度保留在原位置,窗口内部大小维度追加到最后。例如原始维度长度为6,执行unfold(0,3,1)后,原维度变成4(窗口数),末尾新增维度3(窗口大小)。
六、图像中的unfold:两次切分实现patch提取
图像张量通常为[B, C, H, W]。为了沿高度和宽度方向分别用unfold切分,需要先通过permute将通道维度移到末尾:[B, H, W, C]。- x = x.permute(0, 2, 3, 1) # [B, H, W, C]
复制代码 第一次unfold沿高度方向(第1维):- x = x.unfold(1, patch_size, patch_size) # [B, H_num, W, C, patch_H]
复制代码 第二次unfold沿宽度方向(第2维):- x = x.unfold(2, patch_size, patch_size) # [B, H_num, W_num, C, patch_H, patch_W]
复制代码 最后通过contiguous和view整理成标准patch格式:- x = x.contiguous().view(B, -1, C, patch_size, patch_size) # [B, N, C, patch_size, patch_size]
复制代码 其中N = H_num × W_num,即每张图片的patch总数。
七、实战例子:32×32图像切16×16 patch
输入形状[1, 3, 32, 32],patch_size=16。
- permute后:[1, 32, 32, 3]
- 第一次unfold(高度方向):[1, 2, 32, 3, 16](32/16=2)
- 第二次unfold(宽度方向):[1, 2, 2, 3, 16, 16](32/16=2)
- view后:[1, 4, 3, 16, 16](N=2×2=4)
八、实战例子:224×224图像切16×16 patch
输入形状[B, 3, 224, 224],patch_size=16。
高度方向窗口数:224/16=14;宽度方向窗口数:224/16=14。
最终输出形状:[B, 196, 3, 16, 16],每张图片得到196个patch。
通过掌握unfold的参数含义和维度变换规律,开发者可以灵活实现各种滑动窗口操作,在图像分割、Patch Embedding、局部特征提取等场景中高效编码。 |