An implementation of SAS from "To Choose or to Fuse? Scale Selection for Crowd Counting" (AAAI 2021)

Published: December 22, 2023
import torch
import torch.nn.functional as F
from torchvision.models import vgg16_bn as m

# Density head: three parallel branches (1x1, 3x3 and 5x5 receptive fields) are
# concatenated with the input and fused by a 1x1 conv into a 1-channel density map.
class densityhead(torch.nn.Module):
    def __init__(self,inc):
        super(densityhead,self).__init__()
        self.conv1_1=torch.nn.Conv2d(inc,inc//2,1,padding=0)
        self.conv1_2=torch.nn.Conv2d(inc//2,inc,1)

        self.conv2_1=torch.nn.Conv2d(inc,inc//2,1)
        self.conv2_2=torch.nn.Conv2d(inc//2,inc,3,padding=1)

        self.conv3_1 = torch.nn.Conv2d(inc, inc // 2, 1)
        self.conv3_2 = torch.nn.Conv2d(inc // 2, inc, 5, padding=2)

        self.finalconv=torch.nn.Conv2d(inc*4,1,1)

    def forward(self,x):
        x1=self.conv1_1(x)
        x1=self.conv1_2(x1)

        x2=self.conv2_1(x)
        x2=self.conv2_2(x2)

        x3=self.conv3_1(x)
        x3=self.conv3_2(x3)

        x_final=torch.cat([x3,x2,x1,x],1)
        y=self.finalconv(x_final)
        return y
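
# Illustrative shape check: densityhead(64) turns an (N, 64, H, W) feature map
# into an (N, 1, H, W) density map, e.g.
#   densityhead(64)(torch.ones(1, 64, 32, 32)).shape == torch.Size([1, 1, 32, 32])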

# Decoder block: two 3x3 conv + ReLU layers used along the top-down (FPN-style) path.
class decoder(torch.nn.Module):

    def __init__(self,inc,out,mid):
        super(decoder,self).__init__()
        self.conv1 = torch.nn.Conv2d(inc, mid, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(mid, out, 3, padding=1)
        self.relu = torch.nn.functional.relu

    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x


# Confidence head: predicts a 1-channel scale-confidence map from the pooled features.
class confidencehead(torch.nn.Module):

    def __init__(self,inc,out=1):
        super(confidencehead,self).__init__()
        self.conv1=torch.nn.Conv2d(inc,inc//2,3,padding=1)
        self.conv2=torch.nn.Conv2d(inc//2,out,3,padding=1)
        self.relu=torch.nn.functional.relu

    def forward(self,x):
        x=self.relu(self.conv1(x))
        x=self.conv2(x)
        return x



# SAS: VGG16-BN backbone split into five stages, a U-Net/FPN-style top-down path,
# and a density head plus a confidence head attached to every decoder level.
class SAS(torch.nn.Module):

    def __init__(self):
        super(SAS,self).__init__()
        vgg16 = m(pretrained=True)
        features = list(vgg16.features.children())
        # Each stage after the first starts with the previous block's max-pool,
        # so x1..x5 have strides 1, 2, 4, 8, 16 and 64, 128, 256, 512, 512 channels.
        self.body1 = torch.nn.Sequential(*features[0:6])
        self.body2 = torch.nn.Sequential(*features[6:13])
        self.body3 = torch.nn.Sequential(*features[13:23])
        self.body4 = torch.nn.Sequential(*features[23:33])
        self.body5 = torch.nn.Sequential(*features[33:43])
        self.d5=decoder(512,512,1024)
        self.d4 = decoder(512*2, 256, 512)
        self.d3 = decoder(512, 128, 256)
        self.d2 = decoder(256, 64, 128)
        self.d1 = decoder(128, 64, 64)
        self.density_head5=densityhead(512)
        self.density_head4 = densityhead(256)
        self.density_head3 = densityhead(128)
        self.density_head2 = densityhead(64)
        self.density_head1 = densityhead(64)
        self.block_size = 32  # confidence maps are pooled to 1/32 of the input resolution
        self.confidence_head5=confidencehead(512)
        self.confidence_head4 = confidencehead(256)
        self.confidence_head3 = confidencehead(128)
        self.confidence_head2 = confidencehead(64)
        self.confidence_head1 = confidencehead(64)


    def forward(self,x):
        size = x.size()
        # Backbone: extract multi-scale features x1 (stride 1) ... x5 (stride 16).
        x1 = self.body1(x)
        x2 = self.body2(x1)
        x3 = self.body3(x2)
        x4 = self.body4(x3)
        x5 = self.body5(x4)
      
        # Top-down decoding path: upsample, fuse with the skip feature, decode.
        x = self.d5(x5)
        x5_out = x  # P5

        # Upsample to the skip feature's spatial size (a scale factor could be used instead).
        x = F.interpolate(x, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x4, x], 1)
        x = self.d4(x)
        x4_out = x  # P4

        x = F.interpolate(x, size=x3.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x3, x], 1)
        x = self.d3(x)
        x3_out = x  # P3

        x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x2, x], 1)
        x = self.d2(x)
        x2_out = x  # P2

        x = F.interpolate(x, size=x1.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x1, x], 1)
        x = self.d1(x)
        x1_out = x  # P1
       

        # Per-level density maps.
        x5_density = self.density_head5(x5_out)
        x4_density = self.density_head4(x4_out)
        x3_density = self.density_head3(x3_out)
        x2_density = self.density_head2(x2_out)
        x1_density = self.density_head1(x1_out)

        # Pool every level to 1/block_size of the input resolution
        # (the downsampling to 1/k of the original image size in the paper).
        pooled_size = (size[-2] // self.block_size, size[-1] // self.block_size)
        x5_confi = F.adaptive_avg_pool2d(x5_out, output_size=pooled_size)
        x4_confi = F.adaptive_avg_pool2d(x4_out, output_size=pooled_size)
        x3_confi = F.adaptive_avg_pool2d(x3_out, output_size=pooled_size)
        x2_confi = F.adaptive_avg_pool2d(x2_out, output_size=pooled_size)
        x1_confi = F.adaptive_avg_pool2d(x1_out, output_size=pooled_size)

        # Per-level confidence maps.
        x5_confi = self.confidence_head5(x5_confi)
        x4_confi = self.confidence_head4(x4_confi)
        x3_confi = self.confidence_head3(x3_confi)
        x2_confi = self.confidence_head2(x2_confi)
        x1_confi = self.confidence_head1(x1_confi)

        # Bring every density map back to the highest resolution (the x1 level).
        x5_density = F.interpolate(x5_density, size=x1.size()[2:], mode='nearest')
        x4_density = F.interpolate(x4_density, size=x1.size()[2:], mode='nearest')
        x3_density = F.interpolate(x3_density, size=x1.size()[2:], mode='nearest')
        x2_density = F.interpolate(x2_density, size=x1.size()[2:], mode='nearest')
        x1_density = F.interpolate(x1_density, size=x1.size()[2:], mode='nearest')

        # Upsample the confidence maps to the same resolution.
        x5_confi_upsample = F.interpolate(x5_confi, size=x1.size()[2:], mode='nearest')
        x4_confi_upsample = F.interpolate(x4_confi, size=x1.size()[2:], mode='nearest')
        x3_confi_upsample = F.interpolate(x3_confi, size=x1.size()[2:], mode='nearest')
        x2_confi_upsample = F.interpolate(x2_confi, size=x1.size()[2:], mode='nearest')
        x1_confi_upsample = F.interpolate(x1_confi, size=x1.size()[2:], mode='nearest')

        # Sigmoid followed by a softmax across the five scales turns the confidence
        # maps into per-pixel scale-selection weights.
        confidence_map = torch.cat([x5_confi_upsample, x4_confi_upsample,
                                    x3_confi_upsample, x2_confi_upsample, x1_confi_upsample], 1)
        confidence_map = torch.sigmoid(confidence_map)
        confidence_map = torch.softmax(confidence_map, 1)

        # Soft selection: weight each scale's density map by its confidence and sum over
        # scales (out-of-place multiply to keep autograd safe).
        density_map = torch.cat([x5_density, x4_density, x3_density, x2_density, x1_density], 1)
        density_map = density_map * confidence_map
        density = torch.sum(density_map, 1, keepdim=True)

        return density



if __name__ == '__main__':
    a = SAS()
    x = torch.ones((1, 3, 256, 256))
    y = a(x)
    print(y.shape)  # expected: torch.Size([1, 1, 256, 256])
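    # Estimated crowd count = spatial sum of the predicted density map
    # (the standard crowd-counting convention; shown here only for illustration).
    count = y.sum().item()
    print('estimated count:', count)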

Unfortunately, the author was not able to implement the loss function from the paper.
The channel widths of the individual modules here also differ slightly from the official version.
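
In the absence of the paper's loss, a minimal training sketch with a plain pixel-wise MSE between the predicted and a ground-truth density map (the usual crowd-counting baseline) could look like the following. This is only that baseline, not the loss proposed in the paper; gt_density is a hypothetical ground-truth tensor of the same shape as the network output, and the optimizer settings are illustrative.

# Minimal training sketch with a plain MSE density loss (a common baseline,
# NOT the loss proposed in the paper). All hyper-parameters are illustrative.
model = SAS()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = torch.nn.MSELoss(reduction='sum')

img = torch.ones((1, 3, 256, 256))          # dummy input image
gt_density = torch.zeros((1, 1, 256, 256))  # hypothetical ground-truth density map

pred = model(img)                  # (1, 1, 256, 256) predicted density map
loss = criterion(pred, gt_density)
optimizer.zero_grad()
loss.backward()
optimizer.step()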

Source: https://blog.csdn.net/m0_49040755/article/details/135148004