交叉熵的数学推导与代码实现

简单介绍交叉熵(Cross Entropy)

基本概念

什么是交叉熵?

交叉熵(Cross Entropy)是深度学习中最常见的损失函数之一,尤其在分类任务中被广泛应用。它结合了信息论与概率论的思想,能有效衡量模型输出分布与真实分布之间的差异。

信息熵

信息熵:衡量分布不确定性的指标。
$$
H(X) = -\sum^{n}_{i=1}p(x_i)log(p(x_i))
$$

  • X: 事件
  • $p(x_i)$:发生第i种现象的可能性。

系统越是有序,信息熵越低,反之越混乱值越高。

交叉熵(Cross Entropy)

交叉熵:用于衡量两个概率分布之间的距离:

$$
H(p,q) = -\sum^{n}_{i=1}p(x_i)log(q(x_i))
$$

  • $p(x_i)$:真实分布(ground truth)
  • $q(x_i)$:模型预测分布

若p=q:交叉熵最小。

数学推导

二分类任务(Binary Classification)

真实标签$∈{0,1}$, 模型的输出预测为$\hat y ∈(0,1)$,损失函数:

$$
\mathcal{L} = -[y \log(\hat{y}) + (1 - y)\log(1 - \hat{y})]
$$

多分类(Multi-class Classification)

对于 𝐶类问题,模型输出为经过 softmax 的概率分布 $\hat y$,真实标签为 one-hot 向量 𝑦,损失函数为:
$$\mathcal{L} = -\sum_{i=1}^C y_i \log(\hat{y}i)
$$
由于one-hot只有目标类为1,公式简化为:
$$\mathcal{L} = -\log(\hat{y}
{\text{target}})
$$

代码实现

Numpy 实现 Softmax + Cross Entropy

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch

def softmax(logits):
"""
计算 softmax 函数,将 logits 转换为概率分布
:param logits: 未经过 softmax 的 logits 张量
:return: 应用 softmax 之后的概率张量
"""
exp_logits = torch.exp(logits)
return exp_logits / torch.sum(exp_logits, dim=-1, keepdim=True)

def cross_entropy_loss_manual(logits, target):
"""
手动实现交叉熵损失函数
:param logits: 未经过 softmax 的 logits 张量 (batch_size, num_classes)
:param target: 真实标签张量 (batch_size),为类别的索引而不是 one-hot
:return: 计算得到的交叉熵损失
"""
# 先计算 softmax 输出
probs = softmax(logits)

# 取出真实标签对应的概率值 (用 gather 选择每个样本正确标签对应的概率)
log_probs = torch.log(probs)
target_log_probs = log_probs[torch.arange(len(target)), target]

# 计算平均交叉熵损失
loss = -torch.mean(target_log_probs)

return loss

# 测试手动 CrossEntropyLoss
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 1.5]], dtype=torch.float32)
target = torch.tensor([0, 2], dtype=torch.int64) # 标签为类别的索引
manual_ce_loss = cross_entropy_loss_manual(logits, target)
print("手动实现交叉熵损失结果: ", manual_ce_loss)

Pytorch实现(自动处理LogSoftmax)

官方文档

1
2
3
4
5
6
7
8
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])
targets = torch.tensor([0]) # 真实标签为类别索引

loss = F.cross_entropy(logits, targets)
print(f"Loss: {loss.item()}")

⚠️ 注意:F.cross_entropy() 自动包括了 softmax 和 log 操作,因此输入应该是未归一化的 logits。

补充:Label Smoothing对交叉熵的影响

什么是Lable Smoothing?

Label Smoothing 是一种正则化技术,用于避免模型过度自信地预测某个类别。在 one-hot 编码中,真实标签为:

1
[0, 0, 1, 0]  # 100% 相信第 3 类

加入 label smoothing 后,标签会被“平滑”成:

1
[0.1, 0.1, 0.7, 0.1]  # 不再绝对确信

数学形式:

$$
y^{\text{smooth}}_i = (1 - \epsilon) \cdot y_i + \frac{\epsilon}{C}
$$

  • $\epsilon$:平滑因子
  • C:类别数

好处

  • 减少过拟合
  • 提升泛化能力
  • 缓解softmax饱和(梯度消失)

Pytorch

从 PyTorch 1.10 开始,torch.nn.CrossEntropyLoss 支持 label_smoothing 参数:

1
2
3
4
5
6
7
8
9
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.0, 0.1]])
targets = torch.tensor([0])

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
loss = criterion(logits, targets)
print(f"Loss with label smoothing: {loss.item()}")

补充:
Focal loss:处理类别不平衡

作者

Zhou

发布于

2025-04-05

更新于

2025-04-05

许可协议

评论

+ + +