如何在PyTorch中自定义优化器
Optimizer类
所有优化器都继承自Optimizer类。
Optimizer初始化的时候接受params
和defaults
两个变量,分别是模型参数和优化器的超参。
Optimizer的变量如下:
变量名称 |
作用 |
state |
储存状态变量 |
param_groups |
储存每一组参数的值、梯度和超参 |
其中state是一个字典,储存的是优化器中的状态变量,比如Adam
中的mt和vt,其结构如下:
1 2 3 4 5 6 7 8 9 10
| state -- p1 -- state 1 -- Tensor -- state 2 -- ... -- state n -- p2 -- ... -- pn
|
假设一个模型被划分为了N个group,优化器有n个超参数,则param_groups的结构如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| param_groups -- param_group 1 -- params -- p1 -- data -- grad -- ... -- p2 -- ... -- pk -- hyperparameter 1 -- hyperparameter 2 -- ... -- hyperparameter n -- param_group 2 -- ... -- param group N
|
其中不同的param group
是人为划分的,用于给不同的层不同的超参数,类似下面的代码(默认情况下只有一个param group
):
1 2 3 4
| optimizer = Adam([ {'params': model.layer1.parameters(), 'lr': 0.01}, {'params': model.layer2.parameters(), 'lr': 0.001} ])
|
每个param_group
都是一个字典,包含params
和defaults
中所有超参。params
中包含若干个模型参数,简记为p
,如fc1.weight
,fc1.bias
等,p
是一个Tensor
,拥有所有Tensor
的方法。在优化器中,会更关注p.data
和p.grad
,前者储存的是模型的参数,后者储存的是该参数的梯度。
Optimizer类的所有方法如下:
方法名称 |
作用 |
Optimizer.add_param_group |
向Optimizer的param_groups 中添加一个param_group |
Optimizer.load_state_dict |
加载state 变量 |
Optimizer.state_dict |
以字典形式返回optimizer中的state |
Optimizer.step |
进行一次优化操作(参数更新) |
Optimizer.zero_grad |
重置所有模型参数的梯度 |
这些方法主要是在训练模型的时候进行的操作,如Optimizer.step()
和Optimizer.zero_grad()
在自定义优化器的过程中都不是很重要。
In-Place函数
在优化器中,为了节省内存或显存,会大量使用in-place函数来原地操作各种state
和param
,一些常见的in-place函数如下:
函数名称 |
作用 |
add_(x,alpha) |
y←y+αx |
mul_(x) |
y←x⊙y |
addcmul_(x1,x2,value) |
y←y+vx1⊙x2 |
clamp(min,max) |
y←min(max(y,min),max) |
sqrt() |
y←y |
lerp(x,weight) |
y←x+w(y−x) |
代码结构与示例:AdamW
接下来以AdamW
为例介绍一下如何自定义一个优化器,并给出一个通用的自定义优化器的代码结构。
首先给出AdamW
的数学表达:

step0. 导入必要的库
1 2 3 4
| import torch from torch import Tensor from torch.optim import Optimizer, paramsT from typing import Tuple, Union, List
|
其中,Tensor
,params_t
,Tuple
,Union
,List
都是用来判断类型的,paramsT
是判断optimizer
接收的参数params
是否是正确的类型。
1
| paramsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
|
step1. 初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| class AdamW(Optimizer): def __init__(self, params: paramsT, lr: Union[float, Tensor] = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2): defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay ) super().__init__(params, defaults)
|
step2. step函数框架
step()
接收一个参数closure
,其作用是当需要多次评价损失函数的时候返回损失函数,在绝大多数情况下,都不会用到closure
,因此我们不会详细展开(详情见官方文档:https://pytorch.org/docs/stable/optim.html#optimizer-step-closure)。
step函数整体而言分为三个部分,定义列表、初始化列表、参数更新。本节主要定义step函数的框架。初始化列表和参数更新会在后续小节进行讨论。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: self._init_group(...) self._single_tensor_adamw(...)
|
接下来我们看一下AdamW
该如何按照这个框架来实现:
从数学表达式可以得知,AdamW
的每一步需要接收2个模型相关参数:param
、grad
,4个超参数,以及3个state:mt、vt、t,在代码中,我们分别使用exp_avg
、exp_avg_sq
和step
来表示这三个state。其中,超参数对于同一个组里的所有param
都是一样的,而其余变量则随param
变化而变化。因此需要定义5个列表来储存这些变量。
1 2 3 4 5
| params_with_grad = [] grads = [] state_steps = [] exp_avgs = [] exp_avg_sqs = []
|
初始化所有列表:由于所有信息都在group
中,因此传入group
和5个空列表,将group
中的信息提取到列表里:
1 2 3 4 5 6 7 8
| self._init_group( group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps, )
|
参数更新:参数更新的时候就不再传入groups
了,传入之前初始化的5个列表,并将所有超参数分开传入。
1 2 3 4 5 6 7 8 9 10
| self._single_tensor_adamw(params=params_with_grad, grads=grads, exp_avgs=exp_avgs, exp_avg_sqs=exp_avg_sqs, state_steps=state_steps, beta1=group["betas"][0], beta2=group["betas"][1], lr=group["lr"], weight_decay=group["weight_decay"], eps=group["eps"])
|
综上,AdamW
的step
函数如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: params_with_grad = [] grads = [] state_steps = [] exp_avgs = [] exp_avg_sqs = [] self._init_group( group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps, ) self._single_tensor_adamw(params=params_with_grad, grads=grads, exp_avgs=exp_avgs, exp_avg_sqs=exp_avg_sqs, state_steps=state_steps, beta1=group["betas"][0], beta2=group["betas"][1], lr=group["lr"], weight_decay=group["weight_decay"], eps=group["eps"])
|
step3. 初始化列表
主要解决t=0时刻state
中的变量的初始化的问题。剩下就是把所有的变量加到列表中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps): for p in group['params']: if p.grad is None: continue params_with_grad.append(p) grads.append(p.grad) state = self.state[p] if len(state) == 0: state['step'] = torch.tensor(0.0, dtype=torch.float64) state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state_steps.append(state['step']) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq'])
|
注意当要初始化要分类讨论的时候要加None
占位,确保prams[i]
的所有相关变量都在列表的第i位。
step4. 参数更新
逐行公式与代码对应:
更新滑动平均
mt←β1mt+(1−β1)gt
这里直接用插值函数
1
| exp_avg.lerp_(grad, 1 - beta1)
|
vt←β2vt+(1−β2)gt2
1
| exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)
|
计算bias correction
m^t←(1−β1t)mtv^t←(1−β2t)vt
这里由于不希望占用内存空间来计算m^t和v^t,因此仅计算分母。
1 2
| bias_correction1 = 1 - beta1 ** step_t bias_correction2 = 1 - beta2 ** step_t
|
weight decay
weight decay被单独拿出来进行计算
1
| param.mul_(1 - lr * weight_decay)
|
更新参数
在去除了weight decay后,更新参数的公式如下:
θt←θt−1−ηt(v^t+ϵmt^)
将其使用bisa_correction
和mt,vt表示:
θt←θt−1−(1−β1t)ηtmt((1−β2tvt)+ϵ)−1
1 2 3 4
| step_size = lr / bias_correction1 bias_correction2_sqrt = bias_correction2.sqrt() denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) param.addcdiv_(exp_avg, denom, value=-step_size)
|
第一个和第三个式子分别计算的是(1−β1t)ηt和(1−β2tvt)+ϵ
最终所参数更新部分的所有代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| def _single_tensor_adamw(self, params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], state_steps: List[Tensor], beta1: float, beta2: float, lr: float, weight_decay: float, eps: float): for i, param in enumerate(params): grad = grads[i] exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] step_t += 1 param.mul_(1 - lr * weight_decay) exp_avg.lerp_(grad, 1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) bias_correction1 = 1 - beta1 ** step_t bias_correction2 = 1 - beta2 ** step_t
step_size = lr / bias_correction1 bias_correction2_sqrt = bias_correction2.sqrt()
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size)
|