深度学习-Softmax

从概率角度看 Softmax:公式推导、数值稳定性与代码实现

一、引言

在分类任务中,我们常常希望模型输出的是一个概率分布,而非仅仅是 raw logits。Softmax 函数应运而生,它能将任意实数向量“压缩”为概率分布的形式,是神经网络输出层的标配

转换成概率分布之后的各列别的值分布在0~1的区间中,所有类别的概率和为1.

虽然也是激活函数,但是主要用在输出层(多分类任务中)。


二、Softmax 的数学定义

对于输入向量$\mathbf{z} = [z_1, z_2, \dots, z_C]$,Softmax 的定义如下:

$$
Softmax(z_i) = \frac{e^{z_i}} { \sum_{j=1}^{C} e^{z_j}}
$$

其输出满足两个性质:

  • 所有输出为正: $Softmax(z_i) > 0 $
  • 总和为 1:$\sum_i Softmax(z_i) = 1 $

因此,Softmax 的输出可解释为类别的预测概率。


三、为什么 Softmax 能输出概率分布?

  1. 指数函数保证了非负性
    $e^x > 0$ 对任意实数成立。

  2. 归一化使得总和为 1
    将所有指数值除以总和,即可得到一个合法的概率分布。

  3. 可微性 & 梯度友好
    Softmax 函数是可导的,便于神经网络训练中的反向传播。


四、数值稳定性问题与解决方法

当 $z_i$ 很大时,$e^{z_i}$容易导致 数值溢出(overflow)

✅ 解决方法:减去最大值(不影响相对关系)

$$
\text{Softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j} e^{z_j - \max(z)}}
$$

这样可以有效避免过大指数导致的溢出。


五、代码实现

✅ Numpy 实现

1
2
3
4
5
6
7
8
9
10
11
import numpy as np

def softmax(x):
x = x - np.max(x, axis=-1, keepdims=True) # 提升数值稳定性
exp_x = np.exp(x)
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(probs) # 输出概率分布

Pytorch

1
2
3
4
5
6
7
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])
probs = F.softmax(logits, dim=1)
print(probs)

六、Softmax vs Argmax vs Sigmoid

函数 输出含义 常见用途 是否可导 输出范围
Softmax 多分类概率分布 多分类任务 ✅ 可导 (0, 1),总和为1
Argmax 最大值的索引 推理阶段,选类别 ❌ 不可导 整数索引值
Sigmoid 每类独立概率 多标签分类、多输出 ✅ 可导 (0, 1)

⚠️ 注意:Sigmoid 用于 多标签任务,Softmax 用于 互斥多分类任务

七、Softmax输出可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# 输入范围:-5 到 5
x = torch.linspace(-5, 5, steps=100)

# 构造 3 类 logits:一个随 x 变化,另两个固定(0 和 -x)
logits = torch.stack([x, torch.zeros_like(x), -x], dim=1)

# 使用 softmax 计算概率(沿类别维度 dim=1)
probs = F.softmax(logits, dim=1)

# 绘图
plt.plot(x.numpy(), probs[:, 0].numpy(), label="Class 0")
plt.plot(x.numpy(), probs[:, 1].numpy(), label="Class 1")
plt.plot(x.numpy(), probs[:, 2].numpy(), label="Class 2")
plt.xlabel("Input value (logit z)")
plt.ylabel("Softmax Probability")
plt.title("Softmax Output vs Input (PyTorch)")
plt.legend()
plt.grid(True)
plt.show()

可视化结果

  • 当某一类 logit 增大时(比如 x 变大),对应的 softmax 概率也会提高。
  • 而其他类的概率会下降。
  • Logits 是模型的“原始判断力道”
  • Softmax 是把这些力道转成“概率的形式”
  • 得分高的类别,变成的概率也会高!

Think:为什么不直接对 logits 用 CrossEntropy?
因为交叉熵定义的是对概率分布的比较:

  • logits 是“未归一化的原始分数”
  • 概率分布要求的是 softmax 后的“归一化值”

所以你不能直接把 logits 当概率用——需要经过 softmax。

不过 PyTorch 为了效率,让你直接传 logits,它内部帮你做了。


本文由GPT指导生成

作者

Zhou

发布于

2025-04-06

更新于

2025-04-06

许可协议

评论

+ + +