自定义一个 Loss Function,接口和普通的模型是一致的。
import torch
import torch.nn as nn
class LabelSmoothedCrossEntropyCriterion(nn.Module):
    """Label-smoothed cross-entropy computed from log-probabilities.

    For every position the loss is
    ``(1 - smoothing) * nll + (smoothing / C) * smooth`` where ``nll`` is the
    negated log-probability of the target class, ``smooth`` is the negated
    sum of the log-probabilities over all classes, and ``C`` is the size of
    the last dimension of ``lprobs``.

    Args:
        smoothing: Weight taken away from the target class and spread over
            the uniform term. Generally 0.1 is good enough.
        ignore_index: Target value whose positions contribute zero loss
            (typically the padding id). ``None`` disables masking. The value
            does not have to be a valid class index (e.g. ``-100`` works).
        reduce: If true, sum the loss over all positions and return a
            scalar; otherwise return the per-position losses.
    """

    def __init__(self, smoothing, ignore_index=None, reduce=True):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.reduce = reduce

    def forward(self, lprobs, target):
        """Compute the label-smoothed loss.

        Args:
            lprobs: Log-probabilities with the classes on the last dimension
                (presumably the output of ``log_softmax`` -- the values are
                used as-is and are not normalized here).
            target: Integer class indices, either with the shape of
                ``lprobs`` minus its last dimension or already carrying a
                trailing size-1 dimension.

        Returns:
            A scalar tensor when ``reduce`` is true. Otherwise the
            per-position losses: without ``ignore_index`` the trailing
            size-1 dimension is squeezed away, with ``ignore_index`` it is
            kept (historical behavior, preserved for existing callers).
        """
        # gather() needs an index tensor of the same rank as lprobs.
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)

        if self.ignore_index is not None:
            pad_mask = target.eq(self.ignore_index)
            # Ignored positions are zeroed below, so the class gathered for
            # them is irrelevant. Pointing them at class 0 keeps gather()
            # in bounds even when ignore_index is not a valid class id.
            gather_index = target.masked_fill(pad_mask, 0)
        else:
            pad_mask = None
            gather_index = target

        # nll: negative log likelihood, i.e. the cross-entropy against a
        # one-hot target (same quantity as F.nll_loss without reduction).
        nll_loss = -lprobs.gather(dim=-1, index=gather_index)
        # Smoothing reserves some probability for every label, which for the
        # cross-entropy amounts to summing the log-probs of all labels.
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)

        if pad_mask is not None:
            nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
            smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
        else:
            nll_loss = nll_loss.squeeze(-1)
            smooth_loss = smooth_loss.squeeze(-1)

        if self.reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()

        # Spread the smoothing mass uniformly over all classes.
        eps_i = self.smoothing / lprobs.size(-1)
        loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss
        return loss
# generally, 0.1 is good enough
criterion = LabelSmoothedCrossEntropyCriterion(
smoothing=0.1,
ignore_index=task.target_dictionary.pad(),
)

要点:
- 数据对齐
if target.dim() == lprobs.dim() - 1:
    target = target.unsqueeze(-1)
- 核心算子
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
eps_i = self.smoothing / lprobs.size(-1)
loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss
不使用现成的函数,而是通过 gather、sum 等基础算子实现,以获得更精细的控制。
- 掩码过滤
if self.ignore_index is not None:
    pad_mask = target.eq(self.ignore_index)
    nll_loss.masked_fill_(pad_mask, 0.0)
必须手动把 PAD 位置的 Loss 归零,否则梯度会被噪声淹没。
- 还原与归约
if self.reduce:
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
决定是返回一个 Loss 矩阵还是一个 Loss 标量,前者常用于可视化分析等。