基于Pytorch构建一个可训练的BNN

1. 前言

通常咱们在构建CNN的时候都是以32位浮点数为主,这样在网络规模很大的状况下就会占用很是大的内存资源。而后咱们这里来理解一下浮点数的构成,一个float32类型的浮点数由一个符号位,8个指数位以及23个尾数为构成,即:html

符号位[ ] + 指数位[ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] + 尾数[ ]*23python

咱们能够看到,每一个float32浮点数里面一共有 2 23 = 83886087 2^{23}=83886087 个二进制对应表示 2 23 2^{23} 个数,又 1 0 6 < 2 23 < 1 0 8 10^6<2^{23}<10^8 ,因此咱们通常能够精确地表示6位有效数字,可是没法表示 8 8 位有效数字。浮点数有正负因此须要一个符号位来表示,还有 8 8 个指数位来表示指数(指数也是要存储的)因为有正负也就是 [ 127 , 128 ] [-127,128] git

而后完成一个浮点数的加减运算至少有以下过程:github

  • 检查操做数,即若是有一个参与运算的数为 0 0 ,那么直接得出结果。
  • 比较阶码大小完成对阶。
  • 对尾数进行加或减运算。
  • 将结果规格化并进行舍入处理。

关于这一节,能够看这个表述得很是清楚的博客:http://www.cppblog.com/jianjianxiaole/articles/float.htmlweb

从上面对浮点数的介绍来看,若是咱们使用全浮点数的CNN,那么不只仅会占用大量的内存,还须要大量的计算资源。在这种背景下,低比特量化的优点就体验出来了。接下来,咱们就先看一下2016年NIPS提出的《Binarized Neural Networks: Training Neural Networks with Weights andActivations Constrained to +1 or −1》这篇论文,简称BNN,而后再对BNN的Pytorch代码作一个解析。算法

2. BNN的原理

2.1 二值化方案

在训练BNN时,咱们要把网络层的权重和输出值设为1或者-1,下面是论文提出的 2 2 种二值化方法。微信

第一种是直接将大于等于0的参数设置为1,小于0的设置为-1,即:网络

第一种二值化方法

第二种是将绝对值大于1的参数设为1,将绝对值小于1的参数根据距离 + / 1 +/-1 的远近按几率随机置为 + / 1 +/-1 app

第二种二值化方法

其中 σ ( x ) \sigma(x) 是一个clip函数,公式以下:框架

sigma(x)的定义

论文中提到,第二种方法彷佛更加合理,但它也引入了按几率分布的随机比特数,所以硬件实现会消耗不少时间,因此一般会选定第一种方法来对权重和输出值进行量化。

2.2 如何反向传播?

将CNN的权重和输出值二值化之后,梯度信息应当怎么办呢?论文指出梯度仍然不得不用较高精度的实数来存储,由于梯度很小,因此没法使用低精度来正确表达梯度,同时梯度是有高斯白噪声的,累加梯度才能抵消噪声。另外,二值化至关于给权重和输出值添加了噪声,这种噪声具备正则化的做用,能够防止模型过拟合,即它可让权重更加稀疏。

因为 s i g n sign 函数的导数在非0处都是0,因此在梯度回传的时候使用tanh来替代sign进行求导,这里假设损失函数是 C C ,输入是 r r ,对 r r 作二值化可得:

q = s i g n ( r ) q=sign(r)

C C q q 的导数能够用 g q g_q 表示,那么 C C r r 的导数为:

g r = g q 1 r < = 1 g_r=g_q1_{|r|<=1}

其中 1 r < = 1 1_{|r|<=1} 是tanh的梯度,这样就能够进行梯度回传,而后就能够根据梯度不断优化并训练参数了。

这里须要注意的是咱们须要使用BatchNorm层,BN层最大的做用就是能够加速学习并减小权重尺度的影响,带来必定量的正则化并提升CNN的性能,可是BN设计了不少的矩阵运算会下降运算速度。所以,论文提出了一种Shift-based Batch Normalization(SBN)层。SBN的优势是几乎不须要矩阵运算,而且不会带来性能损失。SBN的操做过程以下:

SBN层

这个函数实现了在不使用乘法的状况下近似计算BN,能够提升计算效率。

一样也是为了加速二值网络的训练,改进了AdaMax优化器。具体算法以下图所示。

改进了AdaMax优化器

2.3 第一层怎么办?

因为网络除了输入之外,所有都是二值化的,因此须要对第一层进行处理,将其二值化,整个二值化网络的处理流程以下:

二值化处理过程示意图

整个过程能够表示为:初始化第一层->计算前一层点积的Xnor->计算BatchNorm的符号->执行网络到倒数第二层->计算输出…

以上是假设输入的每一个数字只有8位的状况,若是咱们但愿使用任意 n n 位的整数,那么咱们能够对公式进行推广,即:

L i n e a r Q u a n t ( x , b i t w i d t h ) = c l i p ( r o u n d ( x b i t w i d t h ) × b i t w i d t h , m i n V , m a x V ) LinearQuant(x,bitwidth)=clip(round(\frac{x}{bitwidth})\times bitwidth, minV, maxV)

或者

L o g Q u a n t ( x , b i t w i d t h ) = c l i p ( A P 2 ( x ) , m i n V , m a x V ) LogQuant(x,bitwidth)=clip(AP2(x), minV, maxV)

3. 代码实现

接下来咱们来解析一下Pytorch实现一个BNN,须要注意的是代码实现和上面介绍的原理有不少不一样,首先第一个卷积层没有作二值化,也就是说第一个卷积层是普通的卷积层。对于输入也没有作定点化,即输入仍然为Float。另外,对于BN层和优化器也没有按照论文中的方法来作优化,代码地址以下:https://github.com/666DZY666/model-compression/blob/master/quantization/WbWtAb/models/nin.py

3.1 定义网络结构

下面的代码定义了支持权重和输出值分别可选二值或者三值量化,能够看到核心函数即为Conv2d_Q

import torch.nn as nn
from .util_wt_bab import Conv2d_Q

# *********************量化(三值、二值)卷积*********************
class Tnn_Bin_Conv2d(nn.Module):
    # 参数:last_relu-尾层卷积输入激活
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):
        super(Tnn_Bin_Conv2d, self).__init__()
        self.A = A
        self.W = W
        self.last_relu = last_relu

        # ********************* 量化(三/二值)卷积 *********************
        self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.tnn_bin_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, A=2, W=2):
        super(Net, self).__init__()
        # 模型结构与搭建
        if cfg is None:
            cfg = [192, 160, 96, 192, 192, 192, 192, 192]
        self.tnn_bin = nn.Sequential(
                nn.Conv2d(3, cfg[0], kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(cfg[0]),
                Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, last_relu=1, A=A, W=W),
                nn.Conv2d(cfg[7],  10, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(10),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.tnn_bin(x)
        x = x.view(x.size(0), -1)
        return x

3.2 具体实现

咱们跟进一下Conv2d_Q函数,来看一下二值化的具体代码实现,注意我将代码里面和三值化有关的部分省略了。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

# ********************* 二值(+-1) ***********************
# A 对激活值进行二值化的具体实现,原理中的第一个公式
class Binary_a(Function):

    @staticmethod
    def forward(self, input):
        self.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        #*******************ste*********************
        grad_input = grad_output.clone()
        #****************saturate_ste***************
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        ''' #******************soft_ste***************** size = input.size() zeros = torch.zeros(size).cuda() grad = torch.max(zeros, 1 - torch.abs(input)) #print(grad) grad_input = grad_output * grad '''
        return grad_input
# W 对权重进行二值化的具体实现
class Binary_w(Function):

    @staticmethod
    def forward(self, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input

# ********************* A(特征)量化(二值) ***********************
# 由于咱们使用的网络结构不是彻底的二值化,第一个卷积层是普通卷积接的ReLU激活函数,因此要判断一下
class activation_bin(nn.Module):
  def __init__(self, A):
    super().__init__()
    self.A = A
    self.relu = nn.ReLU(inplace=True)

  def binary(self, input):
    output = Binary_a.apply(input)
    return output

  def forward(self, input):
    if self.A == 2:
      output = self.binary(input)
      # ******************** A —— 一、0 *********************
      #a = torch.clamp(a, min=0)
    else:
      output = self.relu(input)
    return output
# ********************* W(模型参数)量化(三/二值) ***********************
def meancenter_clampConvParams(w):
    mean = w.data.mean(1, keepdim=True)
    w.data.sub(mean) # W中心化(C方向)
    w.data.clamp(-1.0, 1.0) # W截断
    return w
# 对激活值进行二值化
class weight_tnn_bin(nn.Module):
  def __init__(self, W):
    super().__init__()
    self.W = W

  def binary(self, input):
    output = Binary_w.apply(input)
    return output

  def forward(self, input):
   
        # **************************************** W二值 *****************************************
       output = meancenter_clampConvParams(input) # W中心化+截断
        # **************** channel级 - E(|W|) ****************
        E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)
        # **************** α(缩放因子) ****************
        alpha = E
        # ************** W —— +-1 **************
        output = self.binary(output)
        # ************** W * α **************
        output = output * alpha # 若不须要α(缩放因子),注释掉便可
        # **************************************** W三值 *****************************************
    else:
      output = input
    return output

# ********************* 量化卷积(同时量化A/W,并作卷积) ***********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # 实例化调用A和W量化器
        self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
          
    def forward(self, input):
        # 量化A和W
        bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)    
        #print(bin_input)
        #print(tnn_bin_weight)
        # 用量化后的A和W作卷积
        output = F.conv2d(
            input=bin_input, 
            weight=tnn_bin_weight, 
            bias=self.bias, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation, 
            groups=self.groups)
        return output

上面的代码比较好理解,由于它将BNN论文中最难实现的SBN和改进后的AdaMax优化算法省略掉了,而且没有对输入进行定电化,因此编码难度小了不少,这个代码能够验证一下使用BNN以后精度损失。

4. 实验结果

这里贴一下使用上面的网络训练Cifar10图像分类数据集的准确率对比:

试验结果对比

能够看到若是将除了第一层卷积以外的卷积层均换成二值化卷积以后,模型的压缩率高达92%而且准确率也只有1个点的降低,这说明在Cifar10数据集上这种方法确实是有效的。笔者跑了一下这个代码,测试结果和代码做者是相似的。

5. 思考

咱们看一下论文给出的BNN在MNIST/CIFAR-10等数据集上的测试结果:

二值化网络性能测试

能够看到这些简单网络的分类偏差还在可接受的范围以内,但这种二值化网络在ImageNet上的测试结果倒是比较差的,出现了很大的偏差。虽然还存在不少的优化技巧好比放开Tanh的边界,用2Bit的激活函数能够提高一些精度,但在复杂模型下效果仍然不太好。所以,二值化模型的最大缺点应该是不适合复杂模型。另外,新定义的算子在部署时也是一件麻烦的事,还须要专门的推理框架或者定制的硬件来支持。否则就只能像咱们介绍的代码实现那样,使用矩阵乘法来模拟这个二值化计算过程,但加速是很是有限的。


欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,天天分享咱们学习到的新鲜知识。( • ̀ω•́ )✧

有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:

二维码
为了方便读者获取资料以及咱们公众号的做者发布一些Github工程的更新,咱们成立了一个QQ群,二维码以下,感兴趣能够加入。

公众号QQ交流群