A PyTorch implementation of 3x3 local attention

Published: January 6, 2024

Intended as a replacement for convolution.

The per-output complexity is O(window_size * o + c * o), versus O(window_size * c * o) for a 3x3 convolution (c input channels, o output channels, window_size = 9 here).

The parameter count is 1/9 that of the convolution: only the c * o weights of the 1x1 projection, instead of the 9 * c * o weights of a 3x3 kernel.
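A quick sanity check of the parameter-count claim (a minimal sketch; the channel counts are example values, and the bias and the 9-scalar posEmbedding are ignored):

import torch.nn as nn

c, o = 64, 128                                            # example channel counts
conv1x1 = nn.Conv2d(c, o, kernel_size = 1, bias = False)  # projection used by local attention
conv3x3 = nn.Conv2d(c, o, kernel_size = 3, bias = False)  # the convolution it replaces
p1 = sum(p.numel() for p in conv1x1.parameters())
p3 = sum(p.numel() for p in conv3x3.parameters())
print(p1, p3, p3 // p1)                                   # 8192 73728 9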

Its actual effectiveness still needs thorough testing.

Because this implementation is built from PyTorch's off-the-shelf ops (unfold, tile, and so on), its actual compute and memory usage are fairly high, even though in theory both should be small. A version based on custom CUDA kernels could follow later.
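As a rough illustration of where the memory goes (a sketch with assumed sizes and float32 elements): the unfold-based path materializes an (n, 9, c, h, w) window tensor, nine times the activation memory of the (n, c, h, w) feature map.

n, c, h, w = 8, 64, 128, 128                     # assumed batch / feature dimensions
bytes_per_elem = 4                               # float32
feature = n * c * h * w * bytes_per_elem
windows = n * 9 * c * h * w * bytes_per_elem     # the unfolded (n, 9, c, h, w) tensor
print(feature / 2**20, windows / 2**20)          # 32.0 MiB vs 288.0 MiB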

import torch
import torch.nn as nn
import torch.nn.functional as F

class transformConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super(transformConv, self).__init__()
        self.stride = stride
        self.out_channels = out_channels
        self.pad = 1
        # Compute query, key and value; all three share the same 1x1 projection
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        # Random posEmbedding of shape (9, 1, 1, 1): one scalar per window position,
        # registered as an nn.Parameter so it is trained and moves to the GPU with the module
        self.posEmbedding = nn.Parameter(torch.empty(9, 1, 1, 1))
        # Initialize posEmbedding with std = 0.02, following GPT-2
        nn.init.normal_(self.posEmbedding, std = 0.02)

    def forward(self, x):
        # Unfold x into 3x3 windows (im2col)
        k = 3
        j = 3
        in_c = x.shape[1]
        # Generate query, key and value (all three share the projection in conv1)
        x = self.conv1(x)  # (n, c, h, w)
        x_pad = F.pad(x, pad = (1, 1, 1, 1, 0, 0), mode = "replicate")
        x_pad = x_pad.unfold(2, k, self.stride)
        x_pad = x_pad.unfold(3, j, self.stride)  # sliding-window expansion: (n, c, h, w, k, j)
        x_pad = x_pad.permute(0, 4, 5, 1, 2, 3)  # (n, k, j, c, h, w)
        x_pad = x_pad.contiguous().view([x_pad.shape[0], -1, x_pad.shape[3], x_pad.shape[4], x_pad.shape[5]])  # (n, 9, c, h, w)

        # Add positional information; the position embedding goes through the same 1x1 convolution
        tiled_posEmbedding = torch.tile(self.posEmbedding, (1, in_c, 1, 1))  # (9, 1, 1, 1) -> (9, c_in, 1, 1)
        tiled_posEmbedding_conv = self.conv1(tiled_posEmbedding)  # (9, c, 1, 1)
        # Tile to the same shape as x_pad
        tiled_posEmbedding_conv = tiled_posEmbedding_conv.view(1, tiled_posEmbedding_conv.shape[0], tiled_posEmbedding_conv.shape[1], tiled_posEmbedding_conv.shape[2], tiled_posEmbedding_conv.shape[3])  # (1, 9, c, 1, 1)
        tiled_posEmbedding_conv = torch.tile(tiled_posEmbedding_conv, (x_pad.shape[0], 1, 1, x_pad.shape[3], x_pad.shape[4]))  # (1, 9, c, 1, 1) -> (n, 9, c, h, w)

        # Local attention: score each of the 9 window positions against the center pixel,
        # softmax over the window, then take the attention-weighted sum of the window
        x_pad = x_pad.add_(tiled_posEmbedding_conv)  # (n, 9, c, h, w)
        x_center = x_pad[:, 4:5, :, :, :]  # the center of the flattened 3x3 window is index 4
        out = F.softmax(x_center * x_pad, dim = 1).mul(x_pad).sum(dim = 1, keepdim = False)  # (n, c, h, w)

        return out
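A quick shape check of the module (a minimal sketch; the sizes are arbitrary examples):

m = transformConv(32, 64)
y = m(torch.randn(2, 32, 16, 16))
print(y.shape)  # torch.Size([2, 64, 16, 16]) -- spatial size preserved at stride 1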

Swapping it into ResNet:

class BasicBlock(nn.Module):  
    def __init__(self, in_channels, out_channels, stride = 1):
        super(BasicBlock, self).__init__()
        self.stride = stride
        #self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = transformConv(in_channels, out_channels, stride = stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace = True)
        #self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = transformConv(out_channels, out_channels, stride = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.extra(x)
        out = self.relu(out)

        return out
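A check that the attention branch and the 1x1 shortcut downsample consistently at stride 2 (example sizes assumed):

blk = BasicBlock(64, 128, stride = 2)
y = blk(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 16, 16])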

class ResNet(nn.Module):
    def __init__(self, block, layers): 
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = transformConv(64, 64)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride = 1) # h/4
        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2) # h/8
        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2) # h/16
        #self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # h/32

        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride = 1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)

        #x = self.avgpool(x)
        #x = x.view(x.size(0), -1)
        #x = self.fc(x)

        return [x3, x2, x1]
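A sanity check of the pyramid strides noted in the comments above (the input size is an arbitrary example):

net = ResNet(BasicBlock, [2, 2, 2, 2])
x3, x2, x1 = net(torch.randn(1, 3, 256, 256))
print(x1.shape, x2.shape, x3.shape)
# torch.Size([1, 64, 64, 64]) torch.Size([1, 128, 32, 32]) torch.Size([1, 256, 16, 16])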

class FPN(nn.Module):
    def __init__(self, resnet):
        super(FPN, self).__init__()
        self.resnet = resnet
        self.deconv4_2 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 128, kernel_size = (1, 1), stride = 1, padding = 0)  # h/16
        self.deconv4_1 = nn.ConvTranspose2d(
            in_channels = 256, out_channels = 128, kernel_size = (1, 1), stride = 1, padding = 0)  # h/16
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU()
        self.deconv5_2 = nn.ConvTranspose2d(
            in_channels = 64, out_channels = 64, kernel_size = (1, 1), stride = 1, padding = 0)  # h/8
        self.deconv5_1 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 64, kernel_size = (1, 1), stride = 1, padding = 0)  # h/8
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU()

        self.deconv7 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 2, kernel_size = (1, 1), stride = 1, padding = 0)  # h

    def forward(self, x):
        left = self.resnet(x[:, 0:3, :, :])
        # Starting from the top level, build the feature maps level by level
        features = left

        # Upsample the left-image features and fuse them with the lower-level feature map
        upsample = F.interpolate(features[0], size=(features[1].shape[2], features[1].shape[3]), mode='bilinear') 
        features[1] = features[1] + self.deconv4_1(upsample)
        features[1] = self.deconv4_2(features[1])
        features[1] = self.relu4(self.bn4(features[1]))

        upsample = F.interpolate(features[1], size=(features[2].shape[2], features[2].shape[3]), mode='bilinear')
        features[2] = features[2] + self.deconv5_1(upsample)
        features[2] = self.deconv5_2(features[2])
        features[2] = self.relu5(self.bn5(features[2]))
        upsample = features[2]

        right = self.resnet(x[:, 3:6, :, :])
        # Starting from the top level, build the feature maps level by level
        features = right
        # Upsample the right-image features and fuse them with the lower-level feature map
        upsample2 = F.interpolate(features[0], size=(features[1].shape[2], features[1].shape[3]), mode='bilinear') 
        features[1] = features[1] + self.deconv4_1(upsample2)
        features[1] = self.deconv4_2(features[1])
        features[1] = self.relu4(self.bn4(features[1]))

        upsample2 = F.interpolate(features[1], size=(features[2].shape[2], features[2].shape[3]), mode='bilinear')
        features[2] = features[2] + self.deconv5_1(upsample2)
        features[2] = self.deconv5_2(features[2])
        features[2] = self.relu5(self.bn5(features[2]))
        upsample2 = features[2]

        # Concatenate the left and right feature maps (standing in for an explicit correlation step) and predict the optical flow directly
        concat_feature = torch.cat([upsample, upsample2], dim=1)
        concat_feature = F.interpolate(concat_feature, size=(x.shape[2], x.shape[3]), mode='bilinear')  # upsample 4x back to input resolution
        flow = self.deconv7(concat_feature)
        return [flow,]

class FlowNetC(nn.Module):
    expansion = 1

    def __init__(self, batchNorm = True):
        super(FlowNetC, self).__init__()
        self.resnet = ResNet(BasicBlock, [2, 2, 2, 2])
        self.fpn = FPN(self.resnet)

    def forward(self, x):
        if self.training:
            ret = self.fpn(x)
            return ret
        else:
            ret = self.fpn(x)[0]
            return ret

    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]


def flownetc(data=None):
    """FlowNetS model architecture from the
    "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)

    Args:
        data : pretrained weights of the network. will create a new one if not set
    """
    model = FlowNetC(batchNorm = False)
    if data is not None:
        model.load_state_dict(data['state_dict'])
    return model
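An end-to-end smoke test (a minimal sketch; the 6-channel input stacks the left image in channels 0-2 and the right image in channels 3-5, matching the slicing inside FPN.forward):

model = flownetc()
model.eval()                                  # eval mode returns only the flow tensor
with torch.no_grad():
    flow = model(torch.randn(1, 6, 256, 256))
print(flow.shape)  # torch.Size([1, 2, 256, 256]) -- 2-channel flow at input resolution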

Source: https://blog.csdn.net/LYF1993/article/details/135431427