如何在PyTorch中自定义优化器

Optimizer类

所有优化器都继承自Optimizer类。

Optimizer初始化的时候接受paramsdefaults两个变量,分别是模型参数和优化器的超参。

Optimizer的变量如下:

变量名称 作用
state 储存状态变量
param_groups 储存每一组参数的值、梯度和超参

其中state是一个字典,储存的是优化器中的状态变量,比如Adam中的mtm_tvtv_t,其结构如下:

1
2
3
4
5
6
7
8
9
10
state
-- p1
-- state 1
-- Tensor
-- state 2
-- ...
-- state n
-- p2
-- ...
-- pn

假设一个模型被划分为了NN个group,优化器有nn个超参数,则param_groups的结构如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
param_groups
-- param_group 1
-- params
-- p1
-- data
-- grad
-- ...
-- p2
-- ...
-- pk
-- hyperparameter 1
-- hyperparameter 2
-- ...
-- hyperparameter n
-- param_group 2
-- ...
-- param group N

其中不同的param group是人为划分的,用于给不同的层不同的超参数,类似下面的代码(默认情况下只有一个param group):

1
2
3
4
optimizer = Adam([
{'params': model.layer1.parameters(), 'lr': 0.01},
{'params': model.layer2.parameters(), 'lr': 0.001}
])

每个param_group都是一个字典,包含paramsdefaults中所有超参。params中包含若干个模型参数,简记为p,如fc1.weightfc1.bias等,p是一个Tensor,拥有所有Tensor的方法。在优化器中,会更关注p.datap.grad,前者储存的是模型的参数,后者储存的是该参数的梯度。

Optimizer类的所有方法如下:

方法名称 作用
Optimizer.add_param_group 向Optimizer的param_groups中添加一个param_group
Optimizer.load_state_dict 加载state变量
Optimizer.state_dict 以字典形式返回optimizer中的state
Optimizer.step 进行一次优化操作(参数更新)
Optimizer.zero_grad 重置所有模型参数的梯度

这些方法主要是在训练模型的时候进行的操作,如Optimizer.step()Optimizer.zero_grad()在自定义优化器的过程中都不是很重要。

In-Place函数

在优化器中,为了节省内存或显存,会大量使用in-place函数来原地操作各种stateparam,一些常见的in-place函数如下:

函数名称 作用
add_(x,alpha) yy+αx\mathbf{y} \leftarrow \mathbf{y} + \alpha \mathbf{x}
mul_(x) yxy\mathbf{y} \leftarrow \mathbf{x}\odot\mathbf{y}
addcmul_(x1,x2,value) yy+vx1x2\mathbf{y}\leftarrow \mathbf{y}+v\mathbf{x}_1\odot \mathbf{x}_2
clamp(min,max) ymin(max(y,min),max)\mathbf{y}\leftarrow \min(\max(\mathbf{y},min),max)
sqrt() yy\mathbf{y}\leftarrow \sqrt{\mathbf{y}}
lerp(x,weight) yx+w(yx)\mathbf{y}\leftarrow \mathbf{x}+w(\mathbf{y}-\mathbf{x})

代码结构与示例:AdamW

接下来以AdamW为例介绍一下如何自定义一个优化器,并给出一个通用的自定义优化器的代码结构。

首先给出AdamW的数学表达:

step0. 导入必要的库

1
2
3
4
import torch
from torch import Tensor
from torch.optim import Optimizer, paramsT
from typing import Tuple, Union, List

其中,Tensor,params_t,Tuple,Union,List都是用来判断类型的,paramsT是判断optimizer接收的参数params是否是正确的类型。

1
paramsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

step1. 初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class AdamW(Optimizer):
def __init__(self,
params: paramsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2):
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay
)
super().__init__(params, defaults)

step2. step函数框架

step()接收一个参数closure,其作用是当需要多次评价损失函数的时候返回损失函数,在绝大多数情况下,都不会用到closure,因此我们不会详细展开(详情见官方文档:https://pytorch.org/docs/stable/optim.html#optimizer-step-closure)。

step函数整体而言分为三个部分,定义列表、初始化列表、参数更新。本节主要定义step函数的框架。初始化列表和参数更新会在后续小节进行讨论。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
# ---------------------------------------------- #
# 在这里定义一系列的列表
# 将所有模型参数、梯度、state、超参数等都放到列表中
# ---------------------------------------------- #
# ...
# ---------------------------------------------- #
# 在这里初始化前面定义的所有列表
# ---------------------------------------------- #
self._init_group(...)
# ---------------------------------------------- #
# 在这里进行参数更新
# ---------------------------------------------- #
self._single_tensor_adamw(...)

接下来我们看一下AdamW该如何按照这个框架来实现:

从数学表达式可以得知,AdamW的每一步需要接收2个模型相关参数:paramgrad,4个超参数,以及3个state:mtm_tvtv_ttt,在代码中,我们分别使用exp_avgexp_avg_sqstep来表示这三个state。其中,超参数对于同一个组里的所有param都是一样的,而其余变量则随param变化而变化。因此需要定义5个列表来储存这些变量。

1
2
3
4
5
params_with_grad = []
grads = []
state_steps = []
exp_avgs = []
exp_avg_sqs = []

初始化所有列表:由于所有信息都在group中,因此传入group和5个空列表,将group中的信息提取到列表里:

1
2
3
4
5
6
7
8
self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
)

参数更新:参数更新的时候就不再传入groups了,传入之前初始化的5个列表,并将所有超参数分开传入。

1
2
3
4
5
6
7
8
9
10
self._single_tensor_adamw(params=params_with_grad,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
state_steps=state_steps,
beta1=group["betas"][0],
beta2=group["betas"][1],
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"])

综上,AdamWstep函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
# ---------------------------------------------- #
# 在这里定义一系列的列表
# 将所有模型参数、梯度、state、超参数等都放到列表中
# ---------------------------------------------- #
params_with_grad = []
grads = []
state_steps = []
exp_avgs = []
exp_avg_sqs = []
# ---------------------------------------------- #
# 在这里初始化前面定义的所有列表
# ---------------------------------------------- #
self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
)
# ---------------------------------------------- #
# 在这里进行参数更新
# ---------------------------------------------- #
self._single_tensor_adamw(params=params_with_grad,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
state_steps=state_steps,
beta1=group["betas"][0],
beta2=group["betas"][1],
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"])

step3. 初始化列表

主要解决t=0t=0时刻state中的变量的初始化的问题。剩下就是把所有的变量加到列表中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def _init_group(self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps):
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
grads.append(p.grad)

state = self.state[p]
if len(state) == 0:
state['step'] = torch.tensor(0.0, dtype=torch.float64)
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

state_steps.append(state['step'])
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])

注意当要初始化要分类讨论的时候要加None占位,确保prams[i]的所有相关变量都在列表的第i位。

step4. 参数更新

逐行公式与代码对应:

更新滑动平均

mtβ1mt+(1β1)gtm_t \leftarrow \beta_1 m_t+(1-\beta_1)g_t

这里直接用插值函数

1
exp_avg.lerp_(grad, 1 - beta1)

vtβ2vt+(1β2)gt2v_t\leftarrow \beta_2 v_t+(1-\beta_2)g_t^2

1
exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)

计算bias correction

m^tmt(1β1t)v^tvt(1β2t)\begin{gather*} \hat{m}_t\leftarrow \frac{m_t}{(1-\beta_1^t)}\\ \hat{v}_t\leftarrow \frac{v_t}{(1-\beta_2^t)} \end{gather*}

这里由于不希望占用内存空间来计算m^t\hat{m}_tv^t\hat{v}_t​,因此仅计算分母。

1
2
bias_correction1 = 1 - beta1 ** step_t
bias_correction2 = 1 - beta2 ** step_t

weight decay

weight decay被单独拿出来进行计算

1
param.mul_(1 - lr * weight_decay)

更新参数

在去除了weight decay后,更新参数的公式如下:

θtθt1ηt(mt^v^t+ϵ)\theta_t \leftarrow \theta_{t-1}-\eta_t(\frac{\hat{m_t}}{\sqrt{\hat{v}_t}+\epsilon})

将其使用bisa_correctionmtm_tvtv_t表示:

θtθt1ηt(1β1t)mt((vt1β2t)+ϵ)1\theta_t\leftarrow \theta_{t-1} - \frac{\eta_t}{(1-\beta_1^t)}m_t\Big(\big(\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t} }\big)+\epsilon\Big)^{-1}

1
2
3
4
step_size = lr / bias_correction1
bias_correction2_sqrt = bias_correction2.sqrt()
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size)

第一个和第三个式子分别计算的是ηt(1β1t)\frac{\eta_t}{(1-\beta_1^t)}(vt1β2t)+ϵ\big(\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t} }\big)+\epsilon

最终所参数更新部分的所有代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def _single_tensor_adamw(self,
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float):
for i, param in enumerate(params):
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# update step
step_t += 1
# weight decay
param.mul_(1 - lr * weight_decay)
# update exp_avg and exp_avg_sq
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# step = _get_value(step_t)
bias_correction1 = 1 - beta1 ** step_t
bias_correction2 = 1 - beta2 ** step_t

step_size = lr / bias_correction1
bias_correction2_sqrt = bias_correction2.sqrt()

denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

param.addcdiv_(exp_avg, denom, value=-step_size)