Pytorch自动混合精度AMP出现Loss变为NaN
在使用AMP的方法训练UNETR和SwinUNETR的时候,发现Loss在一定epoch后突然变为NaN(巨大)。
本篇为调查原因和解决方案的总结。
现象
使用一般没有结合torch.cuda.amp
的训练方法时状态良好,但是加入amp
后,迭代一定次数,出现nan损失:
出现时间与模型有关。
不少博客表示,是在使用CrossEntropyLoss的时候出现的这个问题。
原因
以下是来自网络博客,大家分析的可能的原因:
- float32换位float16进行计算的时候,后者无法承受过小的值的计算,出现错误。
- float16 的最小正数是约 6e-5,当你试图对一个接近 0 的数(例如 softmax 后的概率值)取 log,就容易出现 underflow
可能是涉及到log
计算的时候,发生的问题。
解决
-
不使用AMP直接计算。(简单粗暴好用)
-
在涉及到
log
计算时,从半精度转换回float32:
1 | log_prob = torch.log(prob.float() + 1e-8) # 注意转成 float32,再加 eps |
会损失精度,但能用。
-
使用安全的内建函数:
F.cross_entropy()
比自己写 log(softmax) 更稳定,因为它内部已经做了 float32 处理。 -
防止梯度爆炸
- 添加梯度裁剪
- GradScaler
我加了,还是NaN了。
GradScaler可以解决“除以0”和Loss过大被判断为INF的情况,但是无法解决网络参数变为了NaN的情况。
参考博客:
Pytorch自动混合精度AMP出现Loss变为NaN
https://zhouwentong7.github.io/2025/07/23/Pytorch自动混合精度AMP出现Loss变为NaN/