深度学习训练中的 large batch size 和 learning rate

今天逛知乎的时候看到这个很经典的问题,相信很多人在训练 DNN 的时候都会遇到过,顺便记录总结下各位知乎大佬的回答。

理解SGD、minibatch-SGD 和 GD

在机器学习优化算法中,梯度下降(gradient descent,简称 GD)是最常用的方法之一,简单来说就是在整个训练集中计算当前的梯度,选定一个步长进行更新。GD 的优点是,基于整个数据集得到的梯度,梯度估计相对较准,更新过程更准确。但也有几个缺点,一个是当训练集较大时,GD 的梯度计算较为耗时,二是现代深度学习网络的 loss function 往往是非凸的,基于凸优化理论,这种情况下优化算法只能收敛到局部最小值,因此使用 GD 训练深度神经网络,最终收收敛点很容易落在初始点附近的一个局部最小值区域,不太容易达到较好的收敛性能。

另一个极端是随机梯度下降(stochastic gradient descent),每次计算梯度只用一个样本,这样做的好处是计算快,而且很适合 online-learning 数据流式到达的场景,但缺点是单个实例产生的梯度估计往往很不准,所以得采用很小的 learning rate,而且由于现代的计算框架 CPU/GPU 的多线程工作,单个实例往往很难占满 CPU/GPU 的使用率,导致计算资源浪费。

折中的方案就是 mini-batch,一次采用 batch size 的实例来估计梯度,这样梯度估计相对于 SGD 更准,同时 batch size 能占满 CPU/GPU 的计算资源,又不像 GD 那样计算整个训练集。同时也由于 mini batch 能有适当的梯度噪声,一定程度上缓解 GD 直接掉进了初始点附近的局部最小导致收敛不好的缺点,所以 mini-batch 的方法也最为常用。

关于增大 batch size 对于梯度估计准确度的影响,分析如下:假设 batch size 为 m,对于一个 mini-batch,loss 为:

$$L = \frac{1}{m}\sum_{i=1}^{m} L(x_i, y_i)$$

梯度:

$$g = \frac{1}{m}\sum_{i=1}^{m} g(x_i, y_i)$$

整个 mini-batch 的梯度方差为

$$Var(g) = Var(\frac{1}{m}\sum_{i=1}^{m} g(x_i, y_i))
= \frac{1}{m} Var(g(x_1, y_1))$$

由于每个样本 (xi, yi) 是随机从训练样本集采样得到的,因此样本梯度的方差相等。可以看到 batch size 增大 m 倍,相当于将梯度的方差减少 m 倍,因此梯度更加准确。

如果要保持方差和原来 SGD 一样,相当于给定了这么大的方差带宽容量,那么就可以增大 learning rate,充分利用这个方差容量,在上式中添加 learning rate,同时利用方差的变化公式,得到等式:

$$\frac{1}{m} Var(\sqrt{m} \ast lr \ast g(x_1, y_1)) = Var(lr \ast g(x_1, y_1))$$

因此可将 learning rate 增加 sqrt(m) 倍,以提高训练速度。

large batch 与 learning rate

在分布式训练中,batch size 随着数据并行的 worker 增加而增大,假设 baseline 的batch size 为 B,learning rate 为 lr,训练 epoch 数为 N。如果保持 baseline 的 learning rate,一般不会有较好的收敛速度和精度。原因如下:对于收敛速度,假设 k 个 worker,每次过的 sample 数量为 kB,因此一个 epoch 下的更新次数为 baseline 的 1/k,而每次更新的 lr 不变,所以要达到 baseline 相同的更新次数,则需要增加 epoch 数量,最大需要增加 k*N 个epoch,因此收敛加速倍数会远远低于 k。对于收敛精度,由于增大了 batch size 使梯度估计相较于 badeline 的梯度更加准确,噪音减少,更容易收敛到附近的局部最小,类似于 GD 的效果。为了解决这个问题,一个方法就是增大 lr,因为 batch 变大梯度估计更准,理应比 baseline 的梯度更确信一些,所以增大 lr,利用更准确的梯度多走一点,提高收敛速度。同时增大lr,让每次走的幅度尽量大一些,如果遇到了尖锐的局部最小,还有可能逃出收敛到更好的地方。

但是 lr 不能无限制的增大,原因分析如下。深度神经网络的 loss surface 往往是高维高度非线性的,可以理解为 loss surface 表面凹凸不平,坑坑洼洼,不像 y = x^2 曲线这样光滑,因此基于当前 weight 计算出来的梯度,往前更新的 learing rate 很大的时候,沿着 loss surface 的切线就走了很大一步,有可能大大偏于原有的 loss surface,示例如下图(a)所示,虚线是当前梯度的方向,也就是当前 loss surface 的切线方向,如果learning rate 过大,那这一步沿切线方向就走了很大一步,如果一直持续这样,那很可能就走向了一个错误的loss surface,如图(b)所示。如果是较小的learning rate,每次只沿切线方向走一小步,虽然有些偏差,依然能大致沿着 loss sourface steepest descent曲线想下降,最终收敛到一个不错的局部最小,如图(c)所示。

因此,如何确定 large batch size 与 learing rate 的关系呢?分别比较 baseline 和 k 个 worker 的 large batch 的更新公式,如下:

$$w_{t+k} = w_t - \eta \frac{1}{n}\sum_{j < k}\sum_{x \epsilon \beta_j} \nabla l(x, w_{t + j})$$

$$w_{t+k} = w_t - \eta \frac{1}{kn}\sum_{j < k}\sum_{x \epsilon \beta_j} \nabla l(x, w_{t})$$

以上是 baseline (batch size B) 和 large batch(batch size kB) 的更新公式,第二个中 large batch 过一步的数据量相当于第一个 baseline k 步过的数据量,loss 和梯度都按找过的数据量取平均,因此,为了保证相同的数据量利用率,第二个中的 learning rate 应该为 baseline 的 k 倍,也就是 learning rate 的 linear scale rule。

linear scale rule 有几个约束,其中一个约束是关于 weight 的约束,第一个公式中每一步更新基于的 weight 都是前一步更新过后的 weight,因此相当于小碎步的走,每走一步都是基于目前真实的 weight 计算梯度做更新的,而第二个公式的这一大步(相比 baseline 相当于 k 步)是基于 t 时刻的 weight 来做更新的。如果在这 k 步之内,W(t+j) ~ W(t) 的话,两者近似没有太大问题,也就是 linear scale rule 问题不大,但在 weight 变化较快的时候,会有问题,尤其是模型在刚开始训练的时候,loss 下特别快,weight 变化很快,W(t+j) ~ W(t) 就不满足。因此在初始训练阶段,一般不会直接将 lr 增大为 k 倍,而是从 baseline 的 lr 慢慢 warmup 到 k 倍,让 linear scale rule 不至于违背得那么明显。第二个约束是 lr 不能无限的放大,根据上面的分析,lr 太大直接沿 loss 切线跑得太远,导致收敛出现问题。同时,有文献指出,当 batch size 变大后,得到好的测试结果所能允许的 lr 范围在变小,也就是说,当 batch size 很小时,比较容易找打一个合适的 lr 达到不错的结果,当 batch size 变大后,可能需要精细地找一个合适的 lr 才能达到较好的结果,这也给实际的 large batch 分布式训练带来了困难。

从理论上来说,lr = batch_size * base lr,因为 batch_size 的增大会导致你 update 次数的减少,所以为了达到相同的效果,应该是同比例增大的。但是更大的 lr 可能会导致收敛的不够好,尤其是在刚开始的时候,如果你使用很大的 lr,可能会直接爆炸,所以可能会需要一些 warmup 来逐步的把 lr 提高到你想设定的 lr。实际应用中发现不一定要同比例增长,有时候可能增大到 batch_size/2 倍的效果已经很不错了。

参考资料