A local-attention module intended to replace convolution.

Complexity is O(window_size * o + c * o), where c is the number of input channels, o the number of output channels, and window_size = 9 for a 3x3 window; a 3x3 convolution costs O(window_size * c * o).
In parameters it is about 1/9 of the convolution, since only a 1x1 projection (c * o weights) is learned instead of a 3x3 kernel (9 * c * o weights).
Accuracy still needs thorough follow-up testing.
Because the implementation is assembled from PyTorch's off-the-shelf operators, its actual compute and memory usage are fairly high, although in theory both should be small. A CUDA-programmed version can be implemented later.
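As a quick sanity check on the 1/9 parameter claim (a minimal sketch; the channel counts are arbitrary, and transformConv itself is defined below), compare a 3x3 convolution against the 1x1 projection plus the 9 position-embedding scalars this module learns:

import torch.nn as nn

c_in, c_out = 64, 64
conv3x3 = nn.Conv2d(c_in, c_out, kernel_size = 3, padding = 1, bias = False)
conv1x1 = nn.Conv2d(c_in, c_out, kernel_size = 1, bias = True)

n_3x3 = sum(p.numel() for p in conv3x3.parameters())       # 9 * 64 * 64 = 36864
n_attn = sum(p.numel() for p in conv1x1.parameters()) + 9  # 64 * 64 + 64 + 9 = 4169
print(n_3x3 / n_attn)                                      # ~8.8, close to 9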
import torch
import torch.nn as nn
import torch.nn.functional as F

class transformConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super(transformConv, self).__init__()
        self.stride = stride
        self.out_channels = out_channels
        self.pad = 1
        # Compute query, key and value; all three share the same 1x1 projection
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        # Learnable position embedding of shape (9, 1, 1, 1); the std = 0.02
        # initialization follows GPT-2. Registering it as an nn.Parameter lets it
        # be trained, saved in the state_dict, and moved to the GPU with the module.
        self.posEmbedding = nn.Parameter(torch.empty((9, 1, 1, 1)))
        nn.init.normal_(self.posEmbedding, std = 0.02)
    def forward(self, x):
        # Unfold x into 3x3 sliding windows (im2col)
        k = 3
        j = 3
        in_c = x.shape[1]
        # Generate query, key and value with the shared 1x1 projection
        x = self.conv1(x)  # (n, c, h, w)
        # 'replicate' mode on a 4D tensor pads only the two spatial dims,
        # so the pad tuple must have length 4
        x_pad = F.pad(x, pad = (1, 1, 1, 1), mode = "replicate")
        x_pad = x_pad.unfold(2, k, self.stride)
        x_pad = x_pad.unfold(3, j, self.stride)  # sliding-window unfold: (n, c, h, w, k, j)
        x_pad = x_pad.permute(0, 4, 5, 1, 2, 3)  # (n, 3, 3, c, h, w)
        x_pad = x_pad.contiguous().view([x_pad.shape[0], -1, x_pad.shape[3], x_pad.shape[4], x_pad.shape[5]])  # (n, 9, c, h, w)
        # Add positional information; the position embedding goes through the same 1x1 projection
        tiled_posEmbedding = torch.tile(self.posEmbedding, (1, in_c, 1, 1))  # (9, 1, 1, 1) -> (9, c_in, 1, 1)
        tiled_posEmbedding_conv = self.conv1(tiled_posEmbedding)  # (9, c, 1, 1)
        # Tile to the same shape as x_pad
        tiled_posEmbedding_conv = tiled_posEmbedding_conv.view(1, tiled_posEmbedding_conv.shape[0], tiled_posEmbedding_conv.shape[1], tiled_posEmbedding_conv.shape[2], tiled_posEmbedding_conv.shape[3])  # (1, 9, c, 1, 1)
        tiled_posEmbedding_conv = torch.tile(tiled_posEmbedding_conv, (x_pad.shape[0], 1, 1, x_pad.shape[3], x_pad.shape[4]))  # (1, 9, c, 1, 1) -> (n, 9, c, h, w)
        # Local attention: the query is the center pixel of each window, which sits
        # at index 4 of the row-major 3x3 flattening (index 3 would be the
        # middle-left neighbor)
        x_pad = x_pad.add_(tiled_posEmbedding_conv)  # (n, 9, c, h, w)
        x_center = x_pad[:, 4:5, :, :, :]
        # Per-channel attention weights over the 9 window positions, then a weighted sum of the values
        out = torch.nn.functional.softmax(x_center * x_pad, dim = 1).mul(x_pad).sum(dim = 1, keepdim = False)  # (n, c, h, w)
        return out
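A quick smoke test (a minimal sketch; the tensor sizes are arbitrary) confirming that the module keeps the spatial resolution at stride 1 and halves it at stride 2, just like a padded 3x3 convolution:

x = torch.randn(2, 16, 32, 32)
print(transformConv(16, 32, stride = 1)(x).shape)  # torch.Size([2, 32, 32, 32])
print(transformConv(16, 32, stride = 2)(x).shape)  # torch.Size([2, 32, 16, 16])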
Plugging it into ResNet
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super(BasicBlock, self).__init__()
        self.stride = stride
        #self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = transformConv(in_channels, out_channels, stride = stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace = True)
        #self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = transformConv(out_channels, out_channels, stride = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection shortcut so the residual add matches in channels and resolution
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.extra(x)
        out = self.relu(out)
        return out
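A shape check for the block (a minimal sketch with arbitrary sizes): at stride = 2, both the attention path and the 1x1 shortcut halve the resolution, so the residual add lines up:

blk = BasicBlock(16, 32, stride = 2)
print(blk(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 32, 16, 16])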
class ResNet(nn.Module):
    def __init__(self, block, layers):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = transformConv(64, 64)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride = 1)   # h/4
        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)  # h/8
        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)  # h/16
        #self.layer4 = self._make_layer(block, 512, layers[3], stride=2)   # h/32
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride = 1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        #x = self.avgpool(x)
        #x = x.view(x.size(0), -1)
        #x = self.fc(x)
        return [x3, x2, x1]
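Running the backbone on a dummy image (a minimal sketch; 256x256 is an arbitrary size) shows the coarse-to-fine pyramid that forward returns:

net = ResNet(BasicBlock, [2, 2, 2, 2])
for f in net(torch.randn(1, 3, 256, 256)):
    print(f.shape)
# torch.Size([1, 256, 16, 16])  h/16
# torch.Size([1, 128, 32, 32])  h/8
# torch.Size([1, 64, 64, 64])   h/4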
class FPN(nn.Module):
    def __init__(self, resnet):
        super(FPN, self).__init__()
        self.resnet = resnet
        self.deconv4_2 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 128, kernel_size = (1, 1), stride = 1, padding = 0)  # h/16
        self.deconv4_1 = nn.ConvTranspose2d(
            in_channels = 256, out_channels = 128, kernel_size = (1, 1), stride = 1, padding = 0)  # h/16
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU()
        self.deconv5_2 = nn.ConvTranspose2d(
            in_channels = 64, out_channels = 64, kernel_size = (1, 1), stride = 1, padding = 0)  # h/8
        self.deconv5_1 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 64, kernel_size = (1, 1), stride = 1, padding = 0)  # h/8
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU()
        self.deconv7 = nn.ConvTranspose2d(
            in_channels = 128, out_channels = 2, kernel_size = (1, 1), stride = 1, padding = 0)  # h

    def forward(self, x):
        left = self.resnet(x[:, 0:3, :, :])
        # Starting from the top level, build the feature maps level by level
        features = left
        # Upsample the left-image features and fuse them with the level below
        upsample = F.interpolate(features[0], size = (features[1].shape[2], features[1].shape[3]), mode = 'bilinear')
        features[1] = features[1] + self.deconv4_1(upsample)
        features[1] = self.deconv4_2(features[1])
        features[1] = self.relu4(self.bn4(features[1]))
        upsample = F.interpolate(features[1], size = (features[2].shape[2], features[2].shape[3]), mode = 'bilinear')
        features[2] = features[2] + self.deconv5_1(upsample)
        features[2] = self.deconv5_2(features[2])
        features[2] = self.relu5(self.bn5(features[2]))
        upsample = features[2]
        right = self.resnet(x[:, 3:6, :, :])
        # Starting from the top level, build the feature maps level by level
        features = right
        # Upsample the right-image features and fuse them with the level below
        upsample2 = F.interpolate(features[0], size = (features[1].shape[2], features[1].shape[3]), mode = 'bilinear')
        features[1] = features[1] + self.deconv4_1(upsample2)
        features[1] = self.deconv4_2(features[1])
        features[1] = self.relu4(self.bn4(features[1]))
        upsample2 = F.interpolate(features[1], size = (features[2].shape[2], features[2].shape[3]), mode = 'bilinear')
        features[2] = features[2] + self.deconv5_1(upsample2)
        features[2] = self.deconv5_2(features[2])
        features[2] = self.relu5(self.bn5(features[2]))
        upsample2 = features[2]
        # Concatenate the left and right features (in place of a correlation op) and predict the flow directly
        concat_feature = torch.cat([upsample, upsample2], dim = 1)
        concat_feature = F.interpolate(concat_feature, size = (x.shape[2], x.shape[3]), mode = 'bilinear')  # upsample 4x, back to input resolution
        flow = self.deconv7(concat_feature)
        return [flow,]
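The FPN expects the left and right images stacked along the channel dimension; a minimal sketch (arbitrary input size) of a forward pass:

fpn = FPN(ResNet(BasicBlock, [2, 2, 2, 2]))
pair = torch.randn(1, 6, 256, 256)  # channels 0:3 = left image, 3:6 = right image
flow = fpn(pair)[0]
print(flow.shape)  # torch.Size([1, 2, 256, 256]): a 2-channel flow field at input resolution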
class FlowNetC(nn.Module):
    expansion = 1

    def __init__(self, batchNorm = True):
        super(FlowNetC, self).__init__()
        self.resnet = ResNet(BasicBlock, [2, 2, 2, 2])
        self.fpn = FPN(self.resnet)

    def forward(self, x):
        if self.training:
            ret = self.fpn(x)
            return ret
        else:
            ret = self.fpn(x)[0]
            return ret

    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]


def flownetc(data = None):
    """FlowNetC model architecture from the
    "FlowNet: Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)

    Args:
        data : pretrained weights of the network. A new network is created if not set.
    """
    model = FlowNetC(batchNorm = False)
    if data is not None:
        model.load_state_dict(data['state_dict'])
    return model
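Usage (a minimal sketch): in training mode the network returns the list of flow predictions for the loss, while in eval mode it returns the flow tensor itself:

model = flownetc()
pair = torch.randn(1, 6, 256, 256)
model.train()
print(len(model(pair)))   # 1: the list of flow predictions
model.eval()
print(model(pair).shape)  # torch.Size([1, 2, 256, 256])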