This paper is "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference", published at CVPR 2018. The authors include Hartwig Adam and Dmitry Kalenichenko of Google Brain's MobileNet work, with Benoit Jacob as first author.
The paper proposes a quantization pipeline in which the entire forward pass runs on integer arithmetic only, together with a training procedure that preserves most of the quantized network's accuracy.
Methods
Quantized inference (forward pass)
- Quantization scheme: asymmetric quantization
- $r=S(q-Z)$
- r is the real value, q is the quantized value
- Z is the zero-point: the value in quantized space that corresponds to r = 0
- an exactly representable zero is needed because of zero-padding
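The affine mapping $r=S(q-Z)$ can be sketched in NumPy; the helper names below are mine, not from the paper:

```python
import numpy as np

# Hypothetical helpers illustrating r = S * (q - Z); names are mine.
def quantize(r, S, Z):
    """Map real values to uint8: q = round(r / S) + Z, saturated to [0, 255]."""
    return np.clip(np.round(r / S) + Z, 0, 255).astype(np.uint8)

def dequantize(q, S, Z):
    """Recover approximate real values: r = S * (q - Z)."""
    return S * (q.astype(np.int32) - Z)

S, Z = 0.05, 70  # arbitrary example scale and zero-point
# Because Z is an integer in the quantized domain, r = 0.0 maps exactly
# to q = Z, so zero-padding introduces no quantization error.
assert quantize(np.array([0.0]), S, Z)[0] == Z
assert dequantize(np.array([Z], dtype=np.uint8), S, Z)[0] == 0.0
```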
- Matrix multiplication
- $r_3=r_1 r_2$, where each matrix has its own S and Z
- $S_3(q_3^{i,k}-Z_3)=\sum_{j}S_1(q_1^{i,j}-Z_1)S_2(q_2^{j,k}-Z_2)$
- $q_3^{i,k}=Z_3+M\sum_{j}(q_1^{i,j}-Z_1)(q_2^{j,k}-Z_2)$
- M is the scale factor, $M:=S_1 S_2/S_3$
- in practice it is applied with one fixed-point multiply and one right shift: $M=2^{-n}M_0$
- $M_0\in[0.5,1)$, stored as an int16 or int32 fixed-point number
- this introduces a small error, but the fixed-point multiply keeps it minimal
- $q_3^{i,k}=Z_3+M(NZ_1 Z_2-Z_1 a_2^k -Z_2 a_1^i +\sum_j q_1^{i,j} q_2^{j,k})$
- where $a_2^{k}:=\sum_{j}q_2^{j,k}$ and $a_1^{i}:=\sum_{j}q_1^{i,j}$; across the whole matrix product these sums are precomputed in $O(N^2)$
- the dominant cost is the core $\sum_j q_1^{i,j}q_2^{j,k}$, which is $O(N^3)$
- Implementation
- intermediate accumulators use int32
- the $q_1 q_2$ core
- int32 += uint8 * uint8
- bias
- $S_{bias}=S_1 S_2,\ Z_{bias}=0$
- after accumulation, the result is cast back down to uint8
- the saturating cast pushes quantization toward using the full uint8 range
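The whole integer pipeline above — the zero-point expansion, the $O(N^2)$ precomputed sums, the fixed-point rescale $M=2^{-n}M_0$, and the final saturating cast — can be sketched in NumPy. The parameter values are arbitrary illustrations, not from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
# Illustrative quantization parameters (arbitrary, not from the paper).
S1, Z1 = 0.02, 120
S2, Z2 = 0.03, 130
S3, Z3 = 0.4, 128
M = S1 * S2 / S3  # real-valued scale factor, here well below 1

# Express M as 2^-n * M0 with M0 in [0.5, 1), stored as int32 fixed point.
n, M0 = 0, M
while M0 < 0.5:
    M0 *= 2
    n += 1
M0_int = int(round(M0 * (1 << 31)))

q1 = rng.integers(0, 256, size=(N, N)).astype(np.int64)
q2 = rng.integers(0, 256, size=(N, N)).astype(np.int64)

# Row/column sums a_1^i and a_2^k: O(N^2) precomputation.
a1 = q1.sum(axis=1, keepdims=True)   # a1[i] = sum_j q1[i, j]
a2 = q2.sum(axis=0, keepdims=True)   # a2[k] = sum_j q2[j, k]

# O(N^3) core (real kernels accumulate uint8*uint8 into int32;
# int64 is used here only because NumPy promotes to it anyway).
core = N * Z1 * Z2 - Z1 * a2 - Z2 * a1 + q1 @ q2

# Rescale by M via the fixed-point multiply plus right shift,
# then add the output zero-point and saturating-cast to uint8.
q3 = np.clip(Z3 + ((core * M0_int) >> (31 + n)), 0, 255).astype(np.uint8)

# Sanity check against the float reference r3 = r1 @ r2: the result
# matches to within one output quantization step S3.
r3 = (S1 * (q1 - Z1)) @ (S2 * (q2 - Z2))
assert np.max(np.abs(S3 * (q3.astype(np.int64) - Z3) - r3)) < S3
```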
Quantization-aware training
- Why simulated-quantization training is needed
- 1. per-channel weight ranges differ hugely (>100x), so channels with small ranges suffer large quantization error
- 2. outlier weight values
- simulate quantization effects in the forward pass of training
- weights and biases are stored in FP32
- $q_{sim}(r;a,b,n)=\mathrm{round}\left(\frac{\mathrm{clamp}(r;a,b)-a}{s(a,b,n)}\right)s(a,b,n)+a$
- n is the number of quantization levels ($2^8$ for 8 bits), and $s(a,b,n)=\frac{b-a}{n-1}$
- biases are not quantized because they are represented as 32-bit integers
- Determining a and b
- weight: a:=min(w) b:=max(w)
- activation: via an EMA (exponential moving average) of the observed min/max
- activation quantization is disabled at the start of training until the statistics stabilize
- a and b are nudged so that 0.0 is exactly representable
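A NumPy sketch of the simulated (fake) quantization above, including the zero-nudging step; the function and variable names are mine:

```python
import numpy as np

def fake_quant(r, a, b, num_bits=8):
    """Simulated quantization following the q_sim formula, with the
    range [a, b] nudged so that 0.0 lies exactly on a quantization level."""
    n = 2 ** num_bits              # number of quantization levels
    s = (b - a) / (n - 1)          # step size s(a, b, n)
    # Nudge: pick an integer zero-point z and shift the range accordingly.
    z = np.clip(np.round(-a / s), 0, n - 1)
    a_nudged = -z * s
    b_nudged = a_nudged + (n - 1) * s
    r = np.clip(r, a_nudged, b_nudged)
    return np.round((r - a_nudged) / s) * s + a_nudged

w = np.array([-1.0, -0.3, 0.0, 0.4, 1.5])
wq = fake_quant(w, w.min(), w.max())   # weights use a = min(w), b = max(w)
assert wq[2] == 0.0                    # 0.0 is exactly representable
assert np.all(np.abs(wq - w) <= (w.max() - w.min()) / (2**8 - 1))
```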
- Handling batch norm
- inference: fold BN into the conv weights, $w_{fold}:=\frac{\gamma w}{EMA(\sigma_B)}$
- training: simulate the folding
- the graph computes the conv twice: a first float conv to obtain $\sigma_B$ and $\mu_B$
- then the folded weights are computed, quantized, and used in a second conv
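A minimal NumPy sketch of the inference-time fold, with a linear layer standing in for a conv (function names and numbers are illustrative, not the paper's code):

```python
import numpy as np

def fold_bn(w, gamma, beta, mu, var, eps=1e-5):
    """Fold BN parameters into the preceding layer's weights, so that
    BN(x @ w.T) == x @ w_fold.T + b_fold per output channel.
    gamma, beta: BN scale/shift; mu, var: (moving-average) batch statistics."""
    sigma = np.sqrt(var + eps)
    w_fold = w * (gamma / sigma)[:, None]   # scale each output channel
    b_fold = beta - gamma * mu / sigma
    return w_fold, b_fold

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 3))
w = rng.normal(size=(2, 3))              # 2 output channels
gamma, beta = np.array([1.5, 0.7]), np.array([0.1, -0.2])
mu, var = np.array([0.2, -0.1]), np.array([1.2, 0.8])

y_bn = gamma * (x @ w.T - mu) / np.sqrt(var + 1e-5) + beta
w_fold, b_fold = fold_bn(w, gamma, beta, mu, var)
assert np.allclose(x @ w_fold.T + b_fold, y_bn)
```

In training, it is these folded weights that are fake-quantized before the second conv, so the forward pass sees the same weights inference will use.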
Experiments
- ImageNet ResNet: about a 2% accuracy drop
- ReLU6>ReLU
- weights are more sensitive to reduced bit depth than activations.
- it’s better to keep weight and activation bit depths the same.
Summary
This paper systematically presents an asymmetric quantization scheme for an entire network and is of great practical value; its method of analyzing and modifying the compute graph is worth borrowing. It also accounts for the sources of accuracy loss: the lack of channel-wise quantization, numerical error (the truncated tail of M and the cast to uint8), and the quantization loss on biases.