[Pytorch]深度模型的显存计算以及优化

原文连接:https://oldpan.me/archives/how-to-calculate-gpu-memoryphp

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》前言

亲,显存炸了,你的显卡快冒烟了!html

torch.FatalError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THC/generic/THCStorage.cu:58

想必这是全部炼丹师们最不想看到的错误,没有之一。python

OUT OF MEMORY,显然是显存装不下你那么多的模型权重还有中间变量,而后程序奔溃了。怎么办,其实办法有不少,及时清空中间变量,优化代码,减小batch,等等等等,都可以减小显存溢出的风险。git

可是这篇要说的是上面这一切优化操做的基础,如何去计算咱们所使用的显存。学会如何计算出来咱们设计的模型以及中间变量所占显存的大小,想必知道了这一点,咱们对本身显存也就会驾轻就熟了。github

如何计算

首先咱们应该了解一下基本的数据量信息:算法

  • 1 G = 1000 MB
  • 1 M = 1000 KB
  • 1 K = 1000 Byte
  • 1 B = 8 bit

好,确定有人会问为何是1000而不是1024,这里不过多讨论,只能说两种说法都是正确的,只是应用场景略有不一样。这里统一按照上面的标准进行计算。bash

而后咱们说一下咱们日常使用的向量所占的空间大小,以Pytorch官方的数据格式为例(全部的深度学习框架数据格式都遵循同一个标准):服务器

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

咱们只须要看左边的信息,在日常的训练中,咱们常用的通常是这两种类型:markdown

  • float32 单精度浮点型
  • int32 整型

通常一个8-bit的整型变量所占的空间为1B也就是8bit。而32位的float则占4B也就是32bit。而双精度浮点型double和长整型long在日常的训练中咱们通常不会使用。网络

ps:消费级显卡对单精度计算有优化,服务器级别显卡对双精度计算有优化。

也就是说,假设有一幅RGB三通道真彩色图片,长宽分别为500 x 500,数据类型为单精度浮点型,那么这张图所占的显存的大小为:500 x 500 x 3 x 4B = 3M。

而一个(256,3,100,100)-(N,C,H,W)的FloatTensor所占的空间为256 x 3 x 100 x 100 x 4B = 31M

很少是吧,不要紧,好戏才刚刚开始。

显存去哪儿了

看起来一张图片(3x256x256)和卷积层(256x100x100)所占的空间并不大,那为何咱们的显存依旧仍是用的比较多,缘由很简单,占用显存比较多空间的并非咱们输入图像,而是神经网络中的中间变量以及使用optimizer算法时产生的巨量的中间参数

咱们首先来简单计算一下Vgg16这个net须要占用的显存:

一般一个模型占用的显存也就是两部分:

  • 模型自身的参数(params)
  • 模型计算产生的中间变量(memory)

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

图片来自cs231n,这是一个典型的sequential-net,自上而下很顺畅,咱们能够看到咱们输入的是一张224x224x3的三通道图像,能够看到一张图像只占用150x4k,但上面标注的是150k,这是由于上图中在计算的时候默认的数据格式是8-bit而不是32-bit,因此最后的结果要乘上一个4。

咱们能够看到,左边的memory值表明:图像输入进去,图片以及所产生的中间卷积层所占的空间。咱们都知道,这些形形色色的深层卷积层也就是深度神经网络进行“思考”的过程:

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

图片从3通道变为64 –> 128 –> 256 –> 512 …. 这些都是卷积层,而咱们的显存也主要是他们占用了。

还有上面右边的params,这些是神经网络的权重大小,能够看到第一层卷积是3×3,而输入图像的通道是3,输出通道是64,因此很显然,第一个卷积层权重所占的空间是 (3 x 3 x 3) x 64。

另外还有一个须要注意的是中间变量在backward的时候会翻倍!

为何,举个例子,下面是一个计算图,输入x,通过中间结果z,而后获得最终变量L

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

咱们在backward的时候须要保存下来的中间值。输出是L,而后输入x,咱们在backward的时候要求Lx的梯度,这个时候就须要在计算链Lx中间的z

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

dz/dx这个中间值固然要保留下来以用于计算,因此粗略估计,backward的时候中间变量的占用了是forward的两倍!

优化器和动量

要注意,优化器也会占用咱们的显存!

为何,看这个式子:

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

上式是典型的SGD随机降低法的整体公式,权重W在进行更新的时候,会产生保存中间变量《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》,也就是在优化的时候,模型中的params参数所占用的显存量会翻倍。

固然这只是SGD优化器,其余复杂的优化器若是在计算时须要的中间变量多的时候,就会占用更多的内存。

模型中哪些层会占用显存

有参数的层即会占用显存的层。咱们通常的卷积层都会占用显存,而咱们常用的激活层Relu没有参数就不会占用了。

占用显存的层通常是:

  • 卷积层,一般的conv2d
  • 全链接层,也就是Linear层
  • BatchNorm层
  • Embedding层

而不占用显存的则是:

  • 刚才说到的激活层Relu等
  • 池化层
  • Dropout层

具体计算方式:

  • Conv2d(Cin, Cout, K): 参数数目:Cin × Cout × K × K
  • Linear(M->N): 参数数目:M×N
  • BatchNorm(N): 参数数目: 2N
  • Embedding(N,W): 参数数目: N × W

额外的显存

总结一下,咱们在整体的训练中,占用显存大概分如下几类:

  • 模型中的参数(卷积层或其余有参数的层)
  • 模型在计算时产生的中间参数(也就是输入图像在计算时每一层产生的输入和输出)
  • backward的时候产生的额外的中间参数
  • 优化器在优化时产生的额外的模型参数

但其实,咱们占用的显存空间为何比咱们理论计算的还要大,缘由大概是由于深度学习框架一些额外的开销吧,不过若是经过上面公式,理论计算出来的显存和实际不会差太多的。

如何优化

优化除了算法层的优化,最基本的优化无非也就一下几点:

  • 减小输入图像的尺寸
  • 减小batch,减小每次的输入图像数量
  • 多使用下采样,池化层
  • 一些神经网络层能够进行小优化,利用relu层中设置inplace
  • 购买显存更大的显卡
  • 从深度学习框架上面进行优化

下篇文章我会说明如何在Pytorch这个深度学习框架中跟踪显存的使用量,而后针对Pytorch这个框架进行有目的显存优化。

参考:
https://blog.csdn.net/liusandian/article/details/79069926

《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》

原文连接:https://ptorch.com/news/181.html


前言


在上篇文章《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》中咱们对如何计算各类变量所占显存大小进行了一些探索。而这篇文章咱们着重讲解如何利用Pytorch深度学习框架的一些特性,去查看咱们当前使用的变量所占用的显存大小,以及一些优化工做。如下代码所使用的平台框架为Pytorch。


优化显存


在Pytorch中优化显存是咱们处理大量数据时必要的作法,由于咱们并不可能拥有无限的显存。显存是有限的,而数据是无限的,咱们只有优化显存的使用量才可以最大化地利用咱们的数据,实现多种多样的算法。


估测模型所占的内存


上篇文章中说过,一个模型所占的显存无非是这两种:



  • 模型权重参数

  • 模型所储存的中间变量


其实权重参数通常来讲并不会占用不少的显存空间,主要占用显存空间的仍是计算时产生的中间变量,当咱们定义了一个model以后,咱们能够经过如下代码简单计算出这个模型权重参数所占用的数据量:


import numpy as np

# model是咱们在pytorch定义的神经网络层
# model.parameters()取出这个model全部的权重参数
para = sum([np.prod(list(p.size())) for p in model.parameters()])

假设咱们有这样一个model:


Sequential(
(conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_1): ReLU(inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_2): ReLU(inplace)
(pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

而后咱们获得的para112576,可是咱们计算出来的仅仅是权重参数的“数量”,单位是B,咱们须要转化一下:


# 下面的type_size是4,由于咱们的参数是float32也就是4B,4个字节
print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

这样就能够打印出:


Model Sequential : params: 0.450304M

可是咱们以前说过一个神经网络的模型,不只仅有权重参数还要计算中间变量的大小。怎么去计算,咱们能够假设一个输入变量,而后将这个输入变量投入这个模型中,而后咱们主动提取这些计算出来的中间变量:


# model是咱们加载的模型
# input是实际中投入的input(Tensor)变量

# 利用clone()去复制一个input,这样不会对input形成影响
input_ = input.clone()
# 确保不须要计算梯度,由于咱们的目的只是为了计算中间变量而已
input_.requires_grad_(requires_grad=False)

mods = list(model.modules())
out_sizes = []

for i in range(1, len(mods)):
m
= mods[i]
# 注意这里,若是relu激活函数是inplace则不用计算
if isinstance(m, nn.ReLU):
if m.inplace:
continue
out
= m(input_)
out_sizes.append(np.array(out.size()))
input_ = out

total_nums = 0
for i in range(len(out_sizes)):
s
= out_sizes[i]
nums = np.prod(np.array(s))
total_nums += nums

上面获得的值是模型在运行时候产生全部的中间变量的“数量”,固然咱们须要换算一下:


# 打印两种,只有 forward 和 foreward、backward的状况
print('Model {} : intermedite variables: {:3f} M (without backward)'
.format(model._get_name(), total_nums * type_size / 1000 / 1000))
print('Model {} : intermedite variables: {:3f} M (with backward)'
.format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))

由于在backward的时候全部的中间变量须要保存下来再来进行计算,因此咱们在计算backward的时候,计算出来的中间变量须要乘个2。


而后咱们得出,上面这个模型的中间变量须要的占用的显存,很显然,中间变量占用的值比模型自己的权重值多多了。若是进行一次backward那么须要的就更多。


Model Sequential : intermedite variables: 336.089600 M (without backward)
Model Sequential : intermedite variables: 672.179200 M (with backward)

咱们总结一下以前的代码:


# 模型显存占用监测函数
# model:输入的模型
# input:实际中须要输入的Tensor变量
# type_size 默认为 4 默认类型为 float32

def modelsize(model, input, type_size=4):
para = sum([np.prod(list(p.size())) for p in model.parameters()])
print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

input_ = input.clone()
input_.requires_grad_(requires_grad=<span class="hljs-keyword">False</span>)

mods = list(model.modules())
out_sizes = []

<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">1</span>, len(mods)):
    m = mods[i]
    <span class="hljs-keyword">if</span> isinstance(m, nn.ReLU):
        <span class="hljs-keyword">if</span> m.inplace:
            <span class="hljs-keyword">continue</span>
    out = m(input_)
    out_sizes.append(np.array(out.size()))
    input_ = out

total_nums = <span class="hljs-number">0</span>
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(len(out_sizes)):
    s = out_sizes[i]
    nums = np.prod(np.array(s))
    total_nums += nums

print(<span class="hljs-string">'Model {} : intermedite variables: {:3f} M (without backward)'</span>
      .format(model._get_name(), total_nums * type_size / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>))
print(<span class="hljs-string">'Model {} : intermedite variables: {:3f} M (with backward)'</span>
      .format(model._get_name(), total_nums * type_size*<span class="hljs-number">2</span> / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>))
input_ = input.clone() input_.requires_grad_(requires_grad=<span class="hljs-keyword">False</span>) mods = list(model.modules()) out_sizes = [] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">1</span>, len(mods)): m = mods[i] <span class="hljs-keyword">if</span> isinstance(m, nn.ReLU): <span class="hljs-keyword">if</span> m.inplace: <span class="hljs-keyword">continue</span> out = m(input_) out_sizes.append(np.array(out.size())) input_ = out total_nums = <span class="hljs-number">0</span> <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(len(out_sizes)): s = out_sizes[i] nums = np.prod(np.array(s)) total_nums += nums print(<span class="hljs-string">'Model {} : intermedite variables: {:3f} M (without backward)'</span> .format(model._get_name(), total_nums * type_size / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>)) print(<span class="hljs-string">'Model {} : intermedite variables: {:3f} M (with backward)'</span> .format(model._get_name(), total_nums * type_size*<span class="hljs-number">2</span> / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>))

固然咱们计算出来的占用显存值仅仅是作参考做用,由于Pytorch在运行的时候须要额外的显存值开销,因此实际的显存会比咱们计算的稍微大一些。


关于inplace=False

咱们都知道激活函数Relu()有一个默认参数inplace,默认设置为False,当设置为True时,咱们在经过relu()计算时的获得的新值不会占用新的空间而是直接覆盖原来的值,这也就是为何当inplace参数设置为True时能够节省一部份内存的缘故。


《如何在Pytorch中精细化利用显存》


牺牲计算速度减小显存使用量


Pytorch-0.4.0出来了一个新的功能,能够将一个计算过程分红两半,也就是若是一个模型须要占用的显存太大了,咱们就能够先计算一半,保存后一半须要的中间结果,而后再计算后一半。


也就是说,新的checkpoint容许咱们只存储反向传播所须要的部份内容。若是当中缺乏一个输出(为了节省内存而致使的),checkpoint将会从最近的检查点从新计算中间输出,以便减小内存使用(固然计算时间增长了):


# 输入
input = torch.rand(1, 10)
# 假设咱们有一个很是深的网络
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
output = model(input)

上面的模型须要占用不少的内存,由于计算中会产生不少的中间变量。为此checkpoint就能够帮助咱们来节省内存的占用了。


# 首先设置输入的input=&gt;requires_grad=True
# 若是不设置可能会致使获得的gradient为0

input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]

# 定义要计算的层函数,能够看到咱们定义了两个
# 一个计算前500个层,另外一个计算后500个层

def run_first_half(*args):
x = args[0]
for layer in layers[:500]:
x = layer(x)
return x

def run_second_half(*args):
x = args[0]
for layer in layers[500:-1]:
x = layer(x)
return x

# 咱们引入新加的checkpoint
from torch.utils.checkpoint import checkpoint

x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# 最后一层单独调出来执行
x = layers-1
x.sum.backward() # 这样就能够了

对于Sequential-model来讲,由于Sequential()中能够包含不少的block,因此官方提供了另外一个功能包:


input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)

from torch.utils.checkpoint import checkpoint_sequential

# 分红两个部分
num_segments = 2
x = checkpoint_sequential(model, num_segments, input)
x.sum().backward() # 这样就能够了

跟踪显存使用状况


显存的使用状况,在编写程序中咱们可能没法精确计算,可是咱们能够经过pynvml这个Nvidia的Python环境库和Python的垃圾回收工具,能够实时地打印咱们使用的显存以及哪些Tensor使用了咱们的显存。


相似于下面的报告:


# 08-Jun-18-17:56:51-gpu_mem_prof

At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">39</span>                        Total Used Memory:<span class="hljs-number">399.4</span>  Mb
At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span>                        Total Used Memory:<span class="hljs-number">992.5</span>  Mb
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span>                         (<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>)     <span class="hljs-number">1.82</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span>                         (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>)     <span class="hljs-number">5.46</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>;
At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                       Total Used Memory:<span class="hljs-number">1088.5</span> Mb
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">64</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)       <span class="hljs-number">0</span>.<span class="hljs-number">14</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">128</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)      <span class="hljs-number">0</span>.<span class="hljs-number">28</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)     <span class="hljs-number">0</span>.<span class="hljs-number">56</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)        <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)     <span class="hljs-number">2.25</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">512</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)     <span class="hljs-number">4.5</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)     <span class="hljs-number">9.0</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">64</span>,)                <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>)     <span class="hljs-number">5.46</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">128</span>,)               <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">256</span>,)               <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">512</span>,)               <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">3</span>,)                 <span class="hljs-number">1.14</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span>                        (<span class="hljs-number">256</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>)     <span class="hljs-number">1.12</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>;
...</code></pre>

如下是相关的代码,目前代码依然有些地方须要修改,等修改完善好我会将完整代码以及使用说明放到github上:https://github.com/Oldpan/Pytorch-Memory-Utils 请你们多多留意。

import datetime
import linecache
import os

import gc
import pynvml
import torch
import numpy as np

print_tensor_sizes = True
last_tensor_sizes = set()
gpu_profile_fn = f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_prof.txt'

# if 'GPU_DEBUG' in os.environ:
# print('profiling gpu usage to ', gpu_profile_fn)

lineno = None
func_name = None
filename = None
module_name = None

# fram = inspect.currentframe()
# func_name = fram.f_code.co_name
# filename = fram.f_globals["__file__"]
# ss = os.path.dirname(os.path.abspath(filename))
# module_name = fram.f_globals["__name__"]

def gpu_profile(frame, event):
    # it is _about to_ execute (!)
    global last_tensor_sizes
    global lineno, func_name, filename, module_name

    if event == 'line':
        try:
            # about _previous_ line (!)
            if lineno is not None:
                pynvml.nvmlInit()
                # handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ['GPU_DEBUG']))
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                line = linecache.getline(filename, lineno)
                where_str = module_name+' '+func_name+':'+' line '+str(lineno)

                with open(gpu_profile_fn, 'a+') as f:
                    f.write(f"At {where_str:&lt;50}"
                            f"Total Used Memory:{meminfo.used/1024**2:&lt;7.1f}Mb\n")

                    if print_tensor_sizes is True:
                        for tensor in get_tensors():
                            if not hasattr(tensor, 'dbg_alloc_where'):
                                tensor.dbg_alloc_where = where_str
                        new_tensor_sizes = {(type(x), tuple(x.size()), np.prod(np.array(x.size()))*4/1024**2,
                                             x.dbg_alloc_where) for x in get_tensors()}
                        for t, s, m, loc in new_tensor_sizes - last_tensor_sizes:
                            f.write(f'+ {loc:&lt;50} {str(s):&lt;20} {str(m)[:4]} M {str(t):&lt;10}\n')
                        for t, s, m, loc in last_tensor_sizes - new_tensor_sizes:
                            f.write(f'- {loc:&lt;50} {str(s):&lt;20} {str(m)[:4]} M {str(t):&lt;10}\n')
                        last_tensor_sizes = new_tensor_sizes
                pynvml.nvmlShutdown()

            # save details about line _to be_ executed
            lineno = None

            func_name = frame.f_code.co_name
            filename = frame.f_globals["__file__"]
            if (filename.endswith(".pyc") or
                    filename.endswith(".pyo")):
                filename = filename[:-1]
            module_name = frame.f_globals["__name__"]
            lineno = frame.f_lineno

            return gpu_profile

        except Exception as e:
            print('A exception occured: {}'.format(e))

    return gpu_profile

def get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensor = obj
            else:
                continue
            if tensor.is_cuda:
                yield tensor
        except Exception as e:
            print('A exception occured: {}'.format(e))

须要注意的是,linecache中的getlines只能读取缓冲过的文件,若是这个文件没有运行过则返回无效值。Python 的垃圾收集机制会在变量没有应引用的时候立马进行回收,可是为何模型中计算的中间变量在执行结束后还会存在呢。既然都没有引用了为何还会占用空间?

一种可能的状况是这些引用不在Python代码中,而是在神经网络层的运行中为了backward被保存为gradient,这些引用都在计算图中,咱们在程序中是没法看到的:

《如何在Pytorch中精细化利用显存》

后记

实际中咱们会有些只使用一次的模型,为了节省显存,咱们须要一边计算一遍清除中间变量,使用del进行操做。限于篇幅这里不进行讲解,下一篇会进行说明。

原文地址:如何在Pytorch中精细化利用显存

        <br>
        原创文章,转载请注明 :<a href="https://ptorch.com/news/181.html" target="_blank">如何在Pytorch中精细化利用显存以及提升Pytorch显存利用率 - pytorch中文网</a><br>
        原文出处:   https://ptorch.com/news/181.html<br>
        问题交流群 :168117787
    </div>
At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">39</span> Total Used Memory:<span class="hljs-number">399.4</span> Mb At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> Total Used Memory:<span class="hljs-number">992.5</span> Mb + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> (<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">1.82</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">5.46</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>; At __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> Total Used Memory:<span class="hljs-number">1088.5</span> Mb + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">14</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">28</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">56</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">2.25</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">4.5</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">9.0</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">5.46</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">3</span>,) <span class="hljs-number">1.14</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.Tensor'</span>&amp;<span class="hljs-keyword">gt</span>; + __main_<span class="hljs-number">_</span> &amp;<span class="hljs-keyword">lt</span>;module&amp;<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">1.12</span> M &amp;<span class="hljs-keyword">lt</span>;class <span class="hljs-string">'torch.nn.parameter.Parameter'</span>&amp;<span class="hljs-keyword">gt</span>; ...</code></pre>import datetime import linecache import os import gc import pynvml import torch import numpy as np print_tensor_sizes = True last_tensor_sizes = set() gpu_profile_fn = f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_prof.txt' # if 'GPU_DEBUG' in os.environ: # print('profiling gpu usage to ', gpu_profile_fn) lineno = None func_name = None filename = None module_name = None # fram = inspect.currentframe() # func_name = fram.f_code.co_name # filename = fram.f_globals["__file__"] # ss = os.path.dirname(os.path.abspath(filename)) # module_name = fram.f_globals["__name__"] def gpu_profile(frame, event): # it is _about to_ execute (!) global last_tensor_sizes global lineno, func_name, filename, module_name if event == 'line': try: # about _previous_ line (!) if lineno is not None: pynvml.nvmlInit() # handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ['GPU_DEBUG'])) handle = pynvml.nvmlDeviceGetHandleByIndex(0) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) line = linecache.getline(filename, lineno) where_str = module_name+' '+func_name+':'+' line '+str(lineno) with open(gpu_profile_fn, 'a+') as f: f.write(f"At {where_str:&lt;50}" f"Total Used Memory:{meminfo.used/1024**2:&lt;7.1f}Mb\n") if print_tensor_sizes is True: for tensor in get_tensors(): if not hasattr(tensor, 'dbg_alloc_where'): tensor.dbg_alloc_where = where_str new_tensor_sizes = {(type(x), tuple(x.size()), np.prod(np.array(x.size()))*4/1024**2, x.dbg_alloc_where) for x in get_tensors()} for t, s, m, loc in new_tensor_sizes - last_tensor_sizes: f.write(f'+ {loc:&lt;50} {str(s):&lt;20} {str(m)[:4]} M {str(t):&lt;10}\n') for t, s, m, loc in last_tensor_sizes - new_tensor_sizes: f.write(f'- {loc:&lt;50} {str(s):&lt;20} {str(m)[:4]} M {str(t):&lt;10}\n') last_tensor_sizes = new_tensor_sizes pynvml.nvmlShutdown() # save details about line _to be_ executed lineno = None func_name = frame.f_code.co_name filename = frame.f_globals["__file__"] if (filename.endswith(".pyc") or filename.endswith(".pyo")): filename = filename[:-1] module_name = frame.f_globals["__name__"] lineno = frame.f_lineno return gpu_profile except Exception as e: print('A exception occured: {}'.format(e)) return gpu_profile def get_tensors(): for obj in gc.get_objects(): try: if torch.is_tensor(obj): tensor = obj else: continue if tensor.is_cuda: yield tensor except Exception as e: print('A exception occured: {}'.format(e))<br> 原创文章,转载请注明 :<a href="https://ptorch.com/news/181.html" target="_blank">如何在Pytorch中精细化利用显存以及提升Pytorch显存利用率 - pytorch中文网</a><br> 原文出处: https://ptorch.com/news/181.html<br> 问题交流群 :168117787 </div>
相关文章
相关标签/搜索