BN的计算公式：$y = \gamma \cdot \dfrac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta$
BN中均值与方差的计算：对每个通道 c，在 batch 维和空间维（b、h、w）上统计均值与方差。
所以对于形状为 (b, c, h, w) 的输入 x，
mean 的形状为 (1, c, 1, 1)，
var 的形状为 (1, c, 1, 1)。
代码
class BatchNorm(nn.Module):
    """Batch normalization for 2-D ``(N, C)`` or 4-D ``(N, C, H, W)`` inputs.

    In training mode the input is normalized with the statistics of the
    current batch, computed per feature/channel, and exponential moving
    averages of those statistics are updated. In evaluation mode the moving
    averages are used instead. The normalized value is then scaled by the
    learnable ``gamma`` and shifted by the learnable ``beta``.

    Args:
        num_features: Number of features (2-D input) or channels (4-D input).
        num_dims: Rank of the expected input. ``2`` selects parameter shape
            ``(1, num_features)``; any other value selects
            ``(1, num_features, 1, 1)``.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are state, not parameters: registering them as
        # buffers makes them follow .to()/.cuda() and appear in state_dict().
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def forward(self, x, momentum=0.9, eps=1e-5):
        """Normalize ``x`` and apply the learnable affine transform.

        Args:
            x: Input tensor of rank 2 or rank 4.
            momentum: Weight kept on the previous moving average; the new
                batch statistic contributes ``1 - momentum``.
            eps: Constant added to the variance for numerical stability.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            AssertionError: In training mode, if ``x`` is not rank 2 or 4.
        """
        if self.training:
            assert len(x.shape) in (2, 4)
            # Reduce over everything except the feature/channel axis.
            reduce_dims = (0,) if len(x.shape) == 2 else (0, 2, 3)
            mean = x.mean(dim=reduce_dims, keepdim=True)
            # Biased (1/m) variance, as in the BN formula; the unbiased
            # estimator would also yield NaN for a single-sample batch.
            var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            x_hat = (x - mean) / torch.sqrt(var + eps)
            # Update the running statistics outside autograd so they never
            # hold on to the computation graph of previous iterations.
            with torch.no_grad():
                self.moving_mean = (
                    momentum * self.moving_mean + (1.0 - momentum) * mean
                )
                self.moving_var = (
                    momentum * self.moving_var + (1.0 - momentum) * var
                )
        else:
            x_hat = (x - self.moving_mean) / torch.sqrt(self.moving_var + eps)
        out = self.gamma * x_hat + self.beta
        return out