Pytorch自动混合精度AMP出现Loss变为NaN

在使用AMP的方法训练UNETR和SwinUNETR的时候,发现Loss在一定epoch后突然变为NaN(巨大)。

本篇为调查原因和解决方案的总结。

现象

使用一般没有结合torch.cuda.amp的训练方法时状态良好,但是加入amp后,迭代一定次数,出现nan损失:

loss NaN

出现时间与模型有关。

不少博客表示,是在使用CrossEntropyLoss的时候出现的这个问题。

原因

以下是来自网络博客,大家分析的可能的原因:

  • float32换位float16进行计算的时候,后者无法承受过小的值的计算,出现错误。
    • float16 的最小正数是约 6e-5,当你试图对一个接近 0 的数(例如 softmax 后的概率值)取 log,就容易出现 underflow

可能是涉及到log计算的时候,发生的问题。

解决

  1. 不使用AMP直接计算。(简单粗暴好用)

  2. 在涉及到log计算时,从半精度转换回float32:

1
log_prob = torch.log(prob.float() + 1e-8)  # 注意转成 float32,再加 eps

会损失精度,但能用。

  1. 使用安全的内建函数:
    F.cross_entropy() 比自己写 log(softmax) 更稳定,因为它内部已经做了 float32 处理。

  2. 防止梯度爆炸

  • 添加梯度裁剪
  1. GradScaler

我加了,还是NaN了。
GradScaler可以解决“除以0”和Loss过大被判断为INF的情况,但是无法解决网络参数变为了NaN的情况。


参考博客:

  1. 训练深度学习网络时候,出现Nan是什么原因,怎么才能避免

  2. 解决PyTorch半精度(AMP)训练nan问题

  3. Nan Loss with torch.cuda.amp and CrossEntropyLoss

作者

Zhou

发布于

2025-07-23

更新于

2025-07-23

许可协议

评论

+ + +