张量(Tensor)维度尺寸对不齐(Expected size xx but got size xx for tensor)

发布时间:2024年01月14日

本文以U-Net举例,演示如何解决张量(Tensor)维度尺寸对不齐的问题
U-Net的网络架构可以参考这篇文章:U-Net原理分析与代码解读
这是本文演示所用的U-Net代码

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # 输入层
        self.input_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)

        # 下采样部分
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.down4 = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )

        # 桥接层 - 将输出通道数修改为1024,以便与down4_out拼接时通道数一致
        self.bridge = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # 上采样部分 - 调整每个上采样的第一个卷积层输入通道数
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(2048, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # 输出层
        self.final_conv = nn.Conv2d(64, NUM_CLASSES, kernel_size=1)

    def forward(self, x):
        x = self.input_conv(x)  # 对原始输入进行处理

        down1_out = self.down1(x)
        down2_out = self.down2(down1_out)
        down3_out = self.down3(down2_out)
        down4_out = self.down4(down3_out)

        bridge_out = self.bridge(down4_out)

        up1_out = self.up1(torch.cat([bridge_out, down4_out], dim=1))

        up2_out = self.up2(torch.cat([up1_out, down3_out], dim=1))
        up3_out = self.up3(torch.cat([up2_out, down2_out], dim=1))
        up4_out = self.up4(torch.cat([up3_out, down1_out], dim=1))

        final_out = self.final_conv(up4_out)
        return torch.sigmoid(final_out)  # 因为是二分类问题,所以输出通过sigmoid激活

假设本文输入的图像是600乘以400像素的尺寸,那么对于本文U-Net代码所需的512乘以512像素的输入是肯定不匹配的。

一、图像缩放
既然输入图像的尺寸与网络所需输入的尺寸不符合,那就将输入图像的尺寸缩放到符合网络所需输入的尺寸就可以了。
在预处理函数中直接对原始图像进行缩放。
本文举例U-Net的所需输入是512乘以512像素,所以直接缩放为512乘以512像素

# 定义预处理函数
def get_transforms():
    # 对于图像的transforms
    image_transforms_list = [
        transforms.Resize((512, 512)),  # 缩放至512x512
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 根据实际数据调整
    ]
    image_transform = transforms.Compose(image_transforms_list)

    # 对于mask的transforms(不需要归一化)
    mask_transforms_list = [
        transforms.Resize((512, 512)),  # 缩放至512x512
        transforms.ToTensor()
    ]
    mask_transform = transforms.Compose(mask_transforms_list)

    return image_transform, mask_transform

二、尺寸裁剪或尺寸填充
由于直接对原始图像进行缩放可能会对丢失一定的原始信息已经可能会扭曲一定的原始信息,所以更加建议使用尺寸裁剪或尺寸填充的方法。
尺寸裁剪或尺寸填充并不是在预处理函数中使用,而是在网络结构前向传播中使用,因为这时往往只需要改动几个像素点,对原始图像的改动较小。

定义尺寸裁剪或尺寸填充函数:
可以通过修改pad_value来决定用什么数值来填充(建议修改成背景的数值

def crop_or_pad_tensor(tensor, height_crop, width_crop, pad_value=0):
    '''
    裁剪或扩展Tensor在高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。
    正数表示扩展(用0填充),负数表示裁剪。

    参数:
    tensor (torch.Tensor): 输入的4维张量,形状为 (batch_size, channels, height, width)
    height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1
    width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1
    pad_value (float or int): 填充时使用的值,默认为0

    返回:
    cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量
    '''
    assert len(tensor.shape) == 4, '输入的tensor应为4维'

    # 获取原始的高度和宽度
    original_height, original_width = tensor.shape[2], tensor.shape[3]

    # 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)
    height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0
    width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0

    # 计算需要填充的数量(正值代表填充,负值代表不填充)
    pad_bottom = abs(height_crop) if height_crop > 0 else 0
    pad_right = abs(width_crop) if width_crop > 0 else 0

    # 先填充,再裁剪
    padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom), mode='constant', value=pad_value)

    # 在高度和宽度维度上进行裁剪(如果需要)
    if height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
        # 同时裁剪高度和宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]
    elif height_to_remove_from_bottom > 0:
        # 只裁剪高度
        cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :]
    elif width_to_remove_from_right > 0:
        # 只裁剪宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :, :-width_to_remove_from_right]
    else:
        # 不裁剪任何维度
        cropped_or_padded_tensor = padded_tensor

    return cropped_or_padded_tensor

在网络架构的forward方法中调用尺寸裁剪或尺寸填充函数:

# 对height裁剪一个像素,对width保持不变
crop_or_pad_tensor(up1_out, -1, 0)
# 对height保持不变,对width裁剪一个像素
crop_or_pad_tensor(up2_out, 0, -1)

在深度学习中,一个四维张量Tensor)通常代表的是批量图像数据,其维度排列通常是[batch_size, channels, height, width]
也就是Batch Size(批大小)Channels(通道数)Height(高度)、Width(宽度)
本文只讨论因Tensor中的heightwidth对不齐问题,batch_sizeChannels比较基础,就不提及了。

文章来源:https://blog.csdn.net/qq_45831414/article/details/135579672
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。