1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
import torch
'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)
'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''
def zqb_view():
print(a.shape) # torch.Size([4, 1, 32, 32])
a1 = a.view(4, 32 * 32)
print(a1.shape) # torch.Size([4, 1024])
a2 = a1.view(4, 1, 32, 32)
print(a2.shape) # torch.Size([4, 1, 32, 32])
# a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
# 要保持输出数据与输入数据总量,防止数据污染
# a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用
a5 = a.view(-1, 32 * 32) # torch.Size([4, 1024]) -1表示该维度保持不变
print