与 LambdaLR 不同,完全作为中介接管了 Optimizer。
以 Transformer 的 Adam + Learning Rate Scheduling 为例。
def get_rate(d_model, step_num, warmup_step):
return d_model**(-0.5) * min(step_num**(-0.5), step_num * warmup_step**(-1.5))
class NoamOpt:
"Optim wrapper that implements rate."
def __init__(self, model_size, factor, warmup, optimizer):
self.optimizer = optimizer # 被包装的真实 Optimizer
self._step = 0 # 当前步数,由于 Noam 公式根据步数,要自行维护
# 方便监控、状态持久化等功能
self.warmup = warmup
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""
在 PyTorch 中,优化器所有的超参数(学习率、动量等)
都存在一个名为param_groups 的 **list of dict** 里
通过这个 property,外部代码调用 noam_opt.param_groups 时
实际是在直接读写底层 Adam 的参数。这让 Wrapper 看起来就像个原生的优化器
"""
return self.optimizer.param_groups
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""
"""
p.grad.data.mul_(c)
遍历参数组中所有带梯度的参数,直接在内存原地修改梯度值。
常用于梯度累积或者防止梯度爆炸
"""
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.mul_(c)
def step(self):
"Update parameters and rate"
"""
拦截原生的更新动作,先改学习率,再调用被包装 Optimizer 优化
"""
self._step += 1
rate = self.rate()
for p in self.param_groups:
p['lr'] = rate # 直接修改底层优化器字典里的学习率
self._rate = rate
self.optimizer.step() # 调用原生接口,让 Adam 用给定学习率更新权重
def rate(self, step = None):
"Implement `lrate` above"
if step is None:
step = self._step
return 0 if not step else self.factor * get_rate(self.model_size, step, self.warmup)
optimizer = NoamOpt(
model_size=arch_args.encoder_embed_dim,
factor=config.lr_factor,
warmup=config.lr_warmup,
optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))