从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

  写完发现名字有点拗口。。- -#git

  你们在作deep learning的时候,应该都遇到过显存不够用,而后不得不去痛苦的减去batchszie,或者砍本身的网络结构呢? 最后跑出来的效果不尽如人意,总以为本身被全世界针对了。。遇到这种状况怎么办? 请使用MXnet的天奇大法带你省显存! 鲁迅曾经说过:你不去试试,怎么会知道本身的idea真的是这么糟糕呢?github

  首先是传送门附上 mxnet-memonger,相应的paper也是值得一看的 Training Deep Nets with Sublinear Memory Cost网络

  实际上repo和paer里面都说的很清楚了,这里简单提一下吧。ide

  1、Why?idea

  节省显存的原理是什么呢?咱们知道,咱们在训练一个网络的时候,显存是用来保存中间的结果的,为何须要保存中间的结果呢,由于在BP算梯度的时候,咱们是须要当前层的值和上一层回传的梯度一块儿才能计算获得的,因此这看来显存是没法节省的?固然不会,简单的举个例子:一个3层的神经网络,咱们能够不保存第二层的结果,在BP到第二层须要它的结果的时候,能够经过第一层的结果来计算出来,这样就节省了很多内存。  提醒一下,这只是我我的的理解,事实上这篇paper一直没有去好好的读一下,有时间在再个笔记。不过大致的意思差很少就是这样。spa

  

  2、How?code

  怎么作呢?分享一下个人trick吧,我通常会在symbol的相加的地方如data = data+ data0这种后面加上一行 data._set_attr(force_mirroring='True'),为何这么作你们能够去看看repo的readme,symbol的地方处理完之后,只有以下就能够了,searchplan会返回一个能够节省显存的的symbol给你,其它地方彻底同样。orm

  

 1 import mxnet as mx
 2 import memonger
 3 
 4 # configure your network
 5 net = my_symbol()
 6 
 7 # call memory optimizer to search possible memory plan.
 8 net_planned = memonger.search_plan(net)
 9 
10 # use as normal
11 model = mx.FeedForward(net_planned, ...)
12 model.fit(...)

  PS:使用的时候要注意,千万不要在又随机性的层例如dropout后面加上mirror,由于这个结果,再算一次就和上一次不一样了,会让你的symbol的loss变得很奇怪。。blog

 

3、总结内存

天奇大法吼啊!

相关文章
相关标签/搜索