? ? ? ? 前段时间想把我模型的输入由DWT子带改为分块的图像块,一顿魔改后,模型跑着跑着损失就朝着奇怪的方向跑去了:要么突然增大,要么变为NAN。
?
?
? ? ? ? ?为什么训练损失会突然变为NAN呢?这个作者将模型训练过程中loss为NAN或INF的原因解释得好好详尽(感谢):Pytorch训练模型损失Loss为Nan或者无穷大(INF)原因_pytorch loss nan-CSDN博客https://blog.csdn.net/ytusdc/article/details/122321907????????我经过输入几番输入打印测试,确认我的输入确实没有问题,那么问题只能出现在模型的前向传播或者反向梯度传播过程中。我跟着这个作者的排查思路,最终定位问题出在梯度反向传播上,于是通过梯度剪裁成功解决NAN问题(我还增大了batch_size的大小,输入修改后,我发现模型运算量减小了,显存支持我每个step跑更大的batch_size了)。pytorch训练过程中出现nan的排查思路_torch判断nan-CSDN博客https://blog.csdn.net/mch2869253130/article/details/111034068修改部分:
if mode == 'train':
# # 1.debug loss
# assert torch.isnan(total_loss).sum() == 0, print(total_loss)
total_loss.backward()
# # 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(net.parameters(), max_norm=3, norm_type=2)
optim.step()
optim.zero_grad()
梯度剪裁:
????????对超出值域范围的梯度进行约束,避免梯度持续大于1,造成梯度爆炸。(没办法规避梯度消失)
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type)