LambdaLR 不同,完全作为中介接管了 Optimizer。

TransformerAdam + Learning Rate Scheduling 为例。

def get_rate(d_model, step_num, warmup_step):   
    return d_model**(-0.5) * min(step_num**(-0.5), step_num * warmup_step**(-1.5))
 
class NoamOpt:  
    "Optim wrapper that implements rate."  
    def __init__(self, model_size, factor, warmup, optimizer):  
        self.optimizer = optimizer  # 被包装的真实 Optimizer
        self._step = 0  # 当前步数,由于 Noam 公式根据步数,要自行维护
        # 方便监控、状态持久化等功能
        self.warmup = warmup  
        self.factor = factor  
        self.model_size = model_size  
        self._rate = 0
  
    @property  
    def param_groups(self):
        """
        在 PyTorch 中,优化器所有的超参数(学习率、动量等)
        都存在一个名为param_groups 的 **list of dict** 里
    
        通过这个 property,外部代码调用 noam_opt.param_groups 时
        实际是在直接读写底层 Adam 的参数。这让 Wrapper 看起来就像个原生的优化器
        """
        return self.optimizer.param_groups 
  
    def multiply_grads(self, c):  
        """Multiplies grads by a constant *c*."""
        """
        p.grad.data.mul_(c)
        遍历参数组中所有带梯度的参数,直接在内存原地修改梯度值。
        
        常用于梯度累积或者防止梯度爆炸
        """  
        for group in self.param_groups:  
            for p in group['params']:  
                if p.grad is not None:  
                    p.grad.data.mul_(c)  
  
    def step(self):  
        "Update parameters and rate" 
        """
        拦截原生的更新动作,先改学习率,再调用被包装 Optimizer 优化
        """ 
        self._step += 1  
        rate = self.rate()  
        for p in self.param_groups:  
            p['lr'] = rate  # 直接修改底层优化器字典里的学习率
        self._rate = rate  
        self.optimizer.step() # 调用原生接口,让 Adam 用给定学习率更新权重
  
    def rate(self, step = None):  
        "Implement `lrate` above"  
        if step is None:  
            step = self._step  
        return 0 if not step else self.factor * get_rate(self.model_size, step, self.warmup)
        
        
optimizer = NoamOpt(  
    model_size=arch_args.encoder_embed_dim,  
    factor=config.lr_factor,  
    warmup=config.lr_warmup,  
    optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))