import torch
from torchvision.models import vgg16_bn as m
class densityhead(torch.nn.Module):
    """Multi-branch head that maps an ``inc``-channel feature map to one channel.

    Three parallel branches each squeeze the input to ``inc // 2`` channels with
    a 1x1 convolution and expand it back to ``inc`` channels with a 1x1, 3x3 or
    5x5 convolution (padded so the spatial size is unchanged). The three branch
    outputs and the input itself are concatenated along the channel axis and
    reduced to a single channel by a final 1x1 convolution.

    Args:
        inc: Number of channels of the incoming feature map.
    """

    def __init__(self, inc):
        super().__init__()
        conv = torch.nn.Conv2d
        half = inc // 2
        # Branch 1: 1x1 squeeze followed by 1x1 expand.
        self.conv1_1 = conv(inc, half, kernel_size=1, padding=0)
        self.conv1_2 = conv(half, inc, kernel_size=1)
        # Branch 2: 1x1 squeeze followed by 3x3 expand.
        self.conv2_1 = conv(inc, half, kernel_size=1)
        self.conv2_2 = conv(half, inc, kernel_size=3, padding=1)
        # Branch 3: 1x1 squeeze followed by 5x5 expand.
        self.conv3_1 = conv(inc, half, kernel_size=1)
        self.conv3_2 = conv(half, inc, kernel_size=5, padding=2)
        # Three branches plus the identity path give 4 * inc input channels.
        self.finalconv = conv(4 * inc, 1, kernel_size=1)

    def forward(self, x):
        """Return the single-channel map computed from ``x``."""
        branch_1x1 = self.conv1_2(self.conv1_1(x))
        branch_3x3 = self.conv2_2(self.conv2_1(x))
        branch_5x5 = self.conv3_2(self.conv3_1(x))
        # Channel order (5x5, 3x3, 1x1, input) must match finalconv's weights.
        stacked = torch.cat((branch_5x5, branch_3x3, branch_1x1, x), dim=1)
        return self.finalconv(stacked)
class decoder(torch.nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1`` so the spatial size of the input is
    preserved.

    Args:
        inc: Number of input channels.
        out: Number of output channels.
        mid: Number of channels between the two convolutions.
    """

    def __init__(self, inc, out, mid):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(inc, mid, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(mid, out, kernel_size=3, padding=1)
        # Functional ReLU kept as an attribute; it holds no parameters.
        self.relu = torch.nn.functional.relu

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x``."""
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))
class confidencehead(torch.nn.Module):
    """Two 3x3 convolutions producing raw (unactivated) confidence scores.

    The first convolution halves the channel count and is followed by a ReLU;
    the second maps to ``out`` channels with no activation. Both use
    ``padding=1`` so the spatial size is preserved.

    Args:
        inc: Number of input channels.
        out: Number of output channels (default 1).
    """

    def __init__(self, inc, out=1):
        super().__init__()
        hidden_channels = inc // 2
        self.conv1 = torch.nn.Conv2d(inc, hidden_channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(hidden_channels, out, kernel_size=3, padding=1)
        # Functional ReLU kept as an attribute; it holds no parameters.
        self.relu = torch.nn.functional.relu

    def forward(self, x):
        """Return conv2(ReLU(conv1(x))) without a final activation."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class SAS(torch.nn.Module):
    """Multi-scale density estimator on top of a VGG16-BN feature extractor.

    The VGG16-BN ``features`` stack is split into five stages. A top-down
    decoder merges each stage with the upsampled output of the deeper level,
    giving five decoded feature maps (P5..P1). Every decoded map feeds

    * a ``densityhead`` producing a one-channel density map, and
    * a ``confidencehead`` applied to a block-pooled copy of the map,
      producing one confidence score per ``block_size`` x ``block_size``
      block of the input image.

    All density and confidence maps are resized (nearest neighbour) to the
    resolution of the first stage; the confidences are normalised across the
    five scales and used as weights for a per-pixel weighted sum of the
    densities ("soft selection").

    Args:
        pretrained: Forwarded to torchvision's ``vgg16_bn``. ``True`` (the
            default, matching the previous hard-coded behaviour) loads
            pretrained weights, which torchvision may need to download.
            Pass ``False`` to build the backbone with random weights.
    """

    def __init__(self, pretrained=True):
        super(SAS, self).__init__()
        vgg16 = m(pretrained=pretrained)
        features = list(vgg16.features.children())
        # Five encoder stages sliced out of the VGG16-BN feature stack.
        self.body1 = torch.nn.Sequential(*features[0:6])
        self.body2 = torch.nn.Sequential(*features[6:13])
        self.body3 = torch.nn.Sequential(*features[13:23])
        self.body4 = torch.nn.Sequential(*features[23:33])
        self.body5 = torch.nn.Sequential(*features[33:43])
        # Decoders. Input widths are skip channels + channels of the deeper
        # decoder output, i.e. stages are expected to emit 64/128/256/512/512.
        self.d5 = decoder(512, 512, 1024)
        self.d4 = decoder(512 * 2, 256, 512)
        self.d3 = decoder(512, 128, 256)
        self.d2 = decoder(256, 64, 128)
        self.d1 = decoder(128, 64, 64)
        self.density_head5 = densityhead(512)
        self.density_head4 = densityhead(256)
        self.density_head3 = densityhead(128)
        self.density_head2 = densityhead(64)
        self.density_head1 = densityhead(64)
        # Side length (in input pixels) of the blocks that share a confidence.
        self.block_size = 32
        self.confidence_head5 = confidencehead(512)
        self.confidence_head4 = confidencehead(256)
        self.confidence_head3 = confidencehead(128)
        self.confidence_head2 = confidencehead(64)
        self.confidence_head1 = confidencehead(64)

    @staticmethod
    def _merge(skip, deeper):
        """Upsample ``deeper`` to ``skip``'s spatial size and concatenate them."""
        # bilinear + align_corners=True is what the deprecated
        # F.upsample_bilinear computed; a scale factor could be used instead
        # of an explicit target size.
        deeper = torch.nn.functional.interpolate(
            deeper, size=skip.size()[2:], mode='bilinear', align_corners=True)
        return torch.cat([skip, deeper], 1)

    def forward(self, x):
        """Predict a density map for the image batch ``x``.

        Args:
            x: Input tensor whose last two dimensions are height and width.

        Returns:
            Tensor with one channel and the spatial size of the first
            encoder stage's output.
        """
        F = torch.nn.functional
        size = x.size()

        # Encoder.
        x1 = self.body1(x)
        x2 = self.body2(x1)
        x3 = self.body3(x2)
        x4 = self.body4(x3)
        x5 = self.body5(x4)

        # Top-down decoder (P5 .. P1).
        x5_out = self.d5(x5)
        x4_out = self.d4(self._merge(x4, x5_out))
        x3_out = self.d3(self._merge(x3, x4_out))
        x2_out = self.d2(self._merge(x2, x3_out))
        x1_out = self.d1(self._merge(x1, x2_out))

        # Confidence is predicted on a grid of 1/block_size the input size
        # (the "downsample to 1/k of the original image" step of the paper).
        # Clamp to at least 1 so inputs smaller than block_size do not
        # request an empty pooling output.
        pooled_size = (max(size[-2] // self.block_size, 1),
                       max(size[-1] // self.block_size, 1))
        out_size = x1.size()[2:]

        scales = (
            (x5_out, self.density_head5, self.confidence_head5),
            (x4_out, self.density_head4, self.confidence_head4),
            (x3_out, self.density_head3, self.confidence_head3),
            (x2_out, self.density_head2, self.confidence_head2),
            (x1_out, self.density_head1, self.confidence_head1),
        )
        densities = []
        confidences = []
        for feat, density_head, confidence_head in scales:
            # mode='nearest' matches the deprecated F.upsample_nearest.
            densities.append(
                F.interpolate(density_head(feat), size=out_size, mode='nearest'))
            pooled = F.adaptive_avg_pool2d(feat, output_size=pooled_size)
            # Upsample the confidence map to the common output resolution.
            confidences.append(
                F.interpolate(confidence_head(pooled), size=out_size, mode='nearest'))

        confidence_map = torch.cat(confidences, 1)
        # Sigmoid followed by softmax over the scale axis is kept exactly as
        # in the original implementation.
        confidence_map = torch.sigmoid(confidence_map)
        confidence_map = F.softmax(confidence_map, 1)
        density_map = torch.cat(densities, 1)
        # Soft selection: confidence-weighted sum over the five scales.
        density = torch.sum(density_map * confidence_map, 1, keepdim=True)
        return density
if __name__ == '__main__':
    # Smoke test: push one all-ones 256x256 RGB image through the network
    # and print the shape of the predicted density map.
    model = SAS()
    dummy_input = torch.ones((1, 3, 256, 256))
    prediction = model(dummy_input)
    print(prediction.shape)
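# Note: unfortunately, the author was not able to implement the loss function described in the paper.
# The channel counts of the modules in the paper differ slightly from those in the official version.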