print('number of model params', sum(p.numel() for p in model.parameters() if p.requires_grad))
sum(p.numel() for p in model.parameters() if p.requires_grad )
可以用来计算参与训练的参数量
model.parameters()
返回模型中所有参数的迭代器。
if p.requires_grad: 这部分使用了一个条件判断,仅考虑那些 requires_grad 属性为 True 的参数。requires_grad
是 PyTorch 中的一个属性,用于指示是否要在参数上计算梯度。
p.numel(): 对于每个满足条件的参数,p.numel() 返回该参数的元素数量,即参数的总数量。numel()
是 PyTorch 张量对象的方法,用于返回张量中元素的总数。
最后,sum(…) 对所有参数的元素数量求和,得到的结果就是模型中所有可学习参数的总数量。