损失函数是模型训练的"评分器",但更关键的是它通过梯度告诉优化器往哪个方向调参数。很多人能背出 MSE 和交叉熵的公式,却说不清为什么回归用前者、分类用后者,也说不清为什么交叉熵配 softmax 时梯度形式会变得异常干净。这篇我们从直觉、机制、公式一路拆到工程边界。

直觉:损失函数是优化的"接口"

训练的本质是解一个优化问题:

θ=argminθE(x,y)D[L(fθ(x),y)]\theta^* = \arg\min_\theta \; \mathbb{E}_{(x,y)\sim D}\,[\,\mathcal{L}(f_\theta(x),\, y)\,]

模型 fθf_\theta 输出预测,损失 L\mathcal{L} 把"预测和真值差多少"压成一个标量。但前向算出来的标量本身并不驱动训练,真正驱动训练的是 θL\nabla_\theta \mathcal{L}。所以选损失函数,本质上是在选梯度的形状——它决定了误差如何反传、不同样本贡献多大、训练是否稳定。

MSE:高斯噪声下的最大似然

均方误差(Mean Squared Error):

LMSE=1Ni=1N(y^iyi)2\mathcal{L}_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N}(\hat{y}_i - y_i)^2

它不是随便定义的。如果假设观测 y=fθ(x)+ϵy = f_\theta(x) + \epsilon,其中 ϵN(0,σ2)\epsilon \sim \mathcal{N}(0, \sigma^2),那么对数似然最大化等价于最小化 (y^y)2\sum (\hat{y}-y)^2MSE = 高斯噪声假设下的极大似然估计,这是它在回归任务里的理论根基。

对单个样本求导:

Ly^=2(y^y)\frac{\partial \mathcal{L}}{\partial \hat{y}} = 2(\hat{y} - y)

梯度正比于残差——预测离得越远,修正越猛。这看似合理,但也是 MSE 的坑:误差被平方放大,对离群点(outlier)极度敏感。一个标注错误的极端样本,可能贡献整个 batch 的大部分梯度。这是工程里 MSE 经常被 Huber loss(小误差用平方、大误差用线性)替代的原因。

交叉熵:分类任务为什么不能用 MSE

分类输出的是概率分布。交叉熵衡量"用预测分布 qq 去编码真实分布 pp 需要多少额外比特":

LCE=cpclogqc\mathcal{L}_{\text{CE}} = -\sum_{c} p_c \log q_c

在常见的单标签分类里 pp 是 one-hot,所以塌缩成:

LCE=logqy\mathcal{L}_{\text{CE}} = -\log q_{y}

只惩罚"真实类别预测概率"的负对数。qy1q_y \to 1 时损失 0\to 0qy0q_y \to 0 时损失 +\to +\infty——对"自信地犯错"给出极重的惩罚。

为什么分类不用 MSE?两个原因。其一,分类的噪声模型是分类分布(categorical),其极大似然对应的就是交叉熵,而非高斯。其二是梯度问题:分类输出通常过 sigmoid/softmax,如果再套 MSE,链式法则里会乘上激活函数的导数 σ(z)\sigma'(z)。当输出饱和(接近 0 或 1)时 σ0\sigma' \approx 0,梯度消失,错得越离谱学得越慢。这是反直觉且致命的。

softmax + 交叉熵:梯度为什么这么干净

把 softmax 和交叉熵当成一个整体看。设 logits 为 zzqc=ezckezkq_c = \frac{e^{z_c}}{\sum_k e^{z_k}},损失 L=logqy\mathcal{L} = -\log q_y。对 logit zjz_j 求导,经过一番代数(softmax 的雅可比 qczj=qc(δcjqj)\frac{\partial q_c}{\partial z_j} = q_c(\delta_{cj} - q_j))会得到惊人简洁的结果:

Lzj=qjpj\frac{\partial \mathcal{L}}{\partial z_j} = q_j - p_j

即"预测概率 − 真实概率"。激活导数项被完美约掉,不会饱和。这正是几乎所有框架把 softmax 和交叉熵合并成一个算子(如 PyTorch 的 CrossEntropyLoss、TensorFlow 的 softmax_cross_entropy_with_logits)的原因:既数值稳定,又利用这个干净的梯度形式。

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
target = torch.tensor([0]) # 真实类别索引,注意是 index 不是 one-hot

loss = F.cross_entropy(logits, target) # 内部 = log_softmax + nll_loss
loss.backward()

# 验证:grad == softmax(logits) - onehot(target)
print(logits.grad) # [[-0.34, 0.30, 0.04]] 量级
print(F.softmax(logits, dim=1) - F.one_hot(target, 3))

数值稳定:log-sum-exp 与永远别手写两步

直接算 eze^{z}zz 大时会溢出到 inf。框架内部用 log-sum-exp 技巧:先减去最大值再取指数,

logkezk=m+logkezkm,m=maxkzk\log\sum_k e^{z_k} = m + \log\sum_k e^{z_k - m},\quad m = \max_k z_k

这样最大项变成 e0=1e^0 = 1,不会溢出,结果数学上完全等价。踩坑提醒:永远不要自己先 softmaxlognll,那会丢精度甚至出 NaN;用框架的融合算子,传入原始 logits。同理,二分类用 BCEWithLogitsLoss 而不是 Sigmoid + BCELoss

工程权衡与常见误区

  • 不要对 logits 重复加 softmax。 把已经过 softmax 的概率再喂给 CrossEntropyLoss,相当于做了两次,梯度信号被严重压扁,模型"学不动"。这是新手最高频的 bug 之一。
  • 标签平滑(label smoothing)。 把 one-hot 的目标从 11 改成 1ε1-\varepsilon、其余类分 ε/(K1)\varepsilon/(K-1),避免模型对训练标签过度自信、抑制过拟合,代价是校准(calibration)行为改变。
  • 类别不平衡。 当某类样本占比极低,平凡的交叉熵会被多数类主导。常用 class weight 或 focal loss:(1qy)γ(1-q_y)^\gamma 这个调制因子会自动降低"已经分对的简单样本"的权重。
  • 回归别无脑 MSE。 有离群点用 Huber/MAE;输出有界或是比例时考虑换参数化。MSE 对量纲敏感,记得做目标归一化。
  • 损失值不可直接横向比较。 交叉熵的绝对数值依赖类别数和分布,不能拿不同任务的 loss 比"谁训得好",要看相对趋势和验证指标。

小结

MSE 和交叉熵不是两条孤立公式,而是两种噪声假设(高斯 vs 分类)下极大似然的自然产物。真正决定训练好坏的是它们的梯度:MSE 梯度正比残差但惧离群点;softmax-交叉熵的梯度收敛到优雅的 qpq - p,既不饱和又数值友好——前提是你用融合算子、喂原始 logits、别把 softmax 做两遍。理解"损失即梯度接口",比记住公式本身重要得多。