如今深度学习中通常咱们学习的参数都是连续的,由于这样在反向传播的时候才能够对梯度进行更新。可是有的时候咱们也会遇到参数是离>散的状况,这样就没有办法进行反向传播了,好比二值神经网络。本文中讲解了如何用pytorch
对二值化的参数进行梯度更新的straight-through estimator
算法。html
STE
核心的思想就是咱们的参数初始化的时候就是float
这样的连续值,当咱们forward
的时候就将原来的连续的参数映射到{-1,, 1}带入到网络进行计算,这样就能够计算网络的输出。而后backward
的时候直接对原来float
的参数进行更新,而不是对二值化的参数更新。这样能够完成对整个网络的更新了。
首先咱们对上面问题进行一下数学的讲解。git
torch.sign
函数, 能够理解为取符号函数backward
的过程当中对$q$求梯度可得 $\frac{\partial loss}{\partial q}$loss
对r
梯度都是0backward
的过程咱们须要修改$\frac{\partial q}{\partial r}$这部分才可使梯度继续更新下去,因此对$\frac{\partial loss}{\partial r}$进行以下修改: $\frac{\partial q}{\partial r} = \frac{\partial loss}{\partial q} * 1\_{|r| \leq 1}$, 其中$1\_{|r| \leq 1}$ 能够看做$Htanh(x) = Clip(x, -1, 1) = max(-1, min(1, x))$对$x$的求导过程, 也就是是说:
$$\frac{\partial loss}{\partial r} = \frac{\partial loss}{\partial q} \frac{\partial Htanh}{\partial r}$$github
首先咱们验证一下使用torch.sign
会是参数的梯度基本上都是0:算法
>>> input = torch.randn(4, requires_grad = True) >>> output = torch.sign(input) >>> loss = output.mean() >>> loss.backward() >>> input tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True) >>> input.grad tensor([0., 0., 0., 0.])
咱们须要重写sign
这个函数,就好像写一个激活函数同样。先看一下代码, github源码:LBSign.py
网络
import torch class LBSign(torch.autograd.Function): @staticmethod def forward(ctx, input): return torch.sign(input) @staticmethod def backward(ctx, grad_output): return grad_output.clamp_(-1, 1)
接下来咱们作一下测试main.py
app
import torch from LBSign import LBSign if __name__ == '__main__': sign = LBSign.apply params = torch.randn(4, requires_grad = True) output = sign(params) loss = output.mean() loss.backward()
而后咱们发现有梯度了函数
>>> params tensor([-0.9143, 0.8993, -1.1235, -0.7928], requires_grad=True) >>> params.grad tensor([0.2500, 0.2500, 0.2500, 0.2500])
接下来咱们对代码就行一下解释pytorch文档连接:学习
ctx
是保存的上下文信息,input
是输入ctx
是保存的上下文信息,grad_output
能够理解成 $\frac{\partial loss}{\partial q}$这一步的梯度信息,咱们须要作的就是让
$$grad\_output * \frac{\partial Htanh}{\partial r}$$ 而不是让pytorch
继续默认的 $$grad\_output * \frac{\partial q}{\partial r}$$
可是咱们能够从上面的公式能够看出函数$Htanh$对$x$求导是1, 当$x \in [-1, 1]$,因此程序就能够化简成保留原来的梯度就好了,而后裁剪到其余范围的。测试
torch.autograd.Function
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
二值网络,围绕STE的那些事儿
Custom binarization layer with straight through estimator gives error
定义torch.autograd.Function的子类,本身定义某些操做,且定义反向求导函数ui