自定义一个 Loss Function,接口和普通的模型是一致的。

import torch
# BUGFIX: was `import torch.nn as torch`, which shadowed the `torch` module
# and left `nn` undefined — `class ...(nn.Module)` below would raise NameError.
import torch.nn as nn
 
class LabelSmoothedCrossEntropyCriterion(nn.Module):
    """Cross-entropy loss with label smoothing.

    The interface mirrors an ordinary module: call it with log-probabilities
    and integer targets.  A fraction ``smoothing`` of the probability mass is
    spread uniformly over all labels (the last dimension of ``lprobs``), and
    the remaining ``1 - smoothing`` is assigned to the gold label.
    """

    def __init__(self, smoothing, ignore_index=None, reduce=True):
        super().__init__()
        self.smoothing = smoothing          # mass reserved for non-gold labels
        self.ignore_index = ignore_index    # label id whose loss is zeroed (e.g. PAD)
        self.reduce = reduce                # True -> scalar sum; False -> per-position loss

    def forward(self, lprobs, target):
        # Give target a trailing singleton dim so it can index into lprobs.
        if lprobs.dim() - target.dim() == 1:
            target = target.unsqueeze(-1)
        # Negative log-likelihood of the gold label — equivalent to F.nll_loss.
        nll_loss = lprobs.gather(dim=-1, index=target).neg()
        # Sum of -log p over *all* labels: the smoothing term of the loss.
        smooth_loss = lprobs.sum(dim=-1, keepdim=True).neg()
        if self.ignore_index is None:
            nll_loss = nll_loss.squeeze(-1)
            smooth_loss = smooth_loss.squeeze(-1)
        else:
            # Zero out positions holding the ignored label so that padding
            # contributes neither loss nor gradient.
            pad_mask = target.eq(self.ignore_index)
            nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
            smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
        if self.reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        # Blend the two terms: (1 - eps) on the gold label, eps spread uniformly.
        uniform_mass = self.smoothing / lprobs.size(-1)
        return (1.0 - self.smoothing) * nll_loss + uniform_mass * smooth_loss
  
# Instantiate the criterion. Smoothing of 0.1 is a widely used default;
# larger values over-flatten the target distribution.
criterion = LabelSmoothedCrossEntropyCriterion(  
    smoothing=0.1,  
    # Zero the loss at padding positions. NOTE(review): `task` is not defined
    # in this snippet — presumably a fairseq-style task object exposing the
    # target vocabulary; confirm it is in scope before this line runs.
    ignore_index=task.target_dictionary.pad(),  
)

要点:

  1. 数据对齐
if target.dim() == lprobs.dim() - 1:
    target = target.unsqueeze(-1)
  2. 核心算子
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
eps_i = self.smoothing / lprobs.size(-1)  
loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss  

不使用现成的函数,而是通过 gather、sum 等基础算子实现更精细的控制。

  3. 掩码过滤
if self.ignore_index is not None:
    pad_mask = target.eq(self.ignore_index)
    nll_loss.masked_fill_(pad_mask, 0.0)

必须手动把 PAD 位置的 Loss 归零,否则梯度会被噪声淹没。

  4. 还原与归约
if self.reduce:  
    nll_loss = nll_loss.sum()  
    smooth_loss = smooth_loss.sum()  

决定是返回一个 Loss 矩阵还是一个 Loss 标量,前者常用于可视化分析等。