自从张量(Tensor)计算这个概念出现后,神经网络的算法就能够看做是一系列的张量计算。所谓的张量,它本来是个数学概念,表示各类向量或者数值之间的关系。PyTorch的张量(torch.Tensor)表示的是N维矩阵与一维数组的关系。html
torch.Tensor的使用方法和numpy很类似(https://pytorch.org/...tensor-tutorial-py),二者惟一的区别在于torch.Tensor可使用GPU来计算,这就比用CPU的numpy要快不少。web
张量计算的种类有不少,好比加法、乘法、矩阵相乘、矩阵转置等,这些计算被称为算子(Operator),它们是PyTorch的核心组件。算法
算子的backend通常是C/C++的拓展程序,PyTorch的backend是称为"ATen"的C/C++库,ATen是"A Tensor"的缩写。数组
PyTorch全部的Operator都定义在Declarations.cwrap和native_functions.yaml这两个文件中,前者定义了从Torch那继承来的legacy operator(aten/src/TH),后者定义的是native operator,是PyTorch的operator。网络
相比于用C++开发的native code,legacy code是在PyTorch编译时由gen.py根据Declarations.cwrap的内容动态生成的。所以,若是你想要trace这些code,须要先编译PyTorch。ide
legacy code的开发要比native code复杂得多。若是能够的话,建议你尽可能避开它们。函数
本文会以矩阵相乘--torch.matmul()为例来分析PyTorch算子的工做流程。学习
我在深刻浅出全链接层(fully connected layer)中有讲在GPU层面是如何进行矩阵相乘的。Nvidia、AMD等公司提供了优化好的线性代数计算库--cuBLAS/rocBLAS/openBLAS,PyTorch只须要调用它们的API便可。优化
Figure 1是torch.matmul()在ATen中的function flow。能够看到,这个flow可不短,这主要是由于不一样类型的tensor(2d or Nd, batched gemm or not,with or without bias,cuda or cpu)的操做也不尽相同。spa
at::matmul()主要负责将Tensor转换成cuBLAS须要的格式。前面说过,Tensor能够是N维矩阵,若是tensor A是3d矩阵,tensor B是2d矩阵,就须要先将3d转成2d;若是它们都是>=3d的矩阵,就要考虑batched matmul的状况;若是bias=True,后续就应该交给at::addmm()来处理;总之,matmul要考虑的事情比想象中要多。
除此以外,不一样的dtype、device和layout须要调用不一样的操做函数,这部分工做交由c10::dispatcher来完成。
dispatcher主要用于动态调用dtype、device以及layout等方法函数。用过numpy的都知道,np.array()的数据类型有:float32, float16,int8,int32,.... 若是你了解C++就会知道,这类程序最适合用模板(template)来实现。
很遗憾,因为ATen有一部分operator是用C语言写的(从Torch继承过来),不支持模板功能,所以,就须要dispatcher这样的动态调度器。
相似地,PyTorch的tensor不只能够运行在GPU上,还能够跑在CPU、mkldnn和xla等设备,Figure 1中的dispatcher4就根据tensor的device调用了mm的GPU实现。
layout是指tensor中元素的排布。通常来讲,矩阵的排布都是紧凑型的,也就是strided layout。而那些有着大量0的稀疏矩阵,相应地就是sparse layout。
Figure 2是strided layout的演示实例,这里建立了一个2行2列的矩阵a,它的数据实际存放在一维数组(a.storage)里,2行2列只是这个数组的视图。
stride充当了从数组到视图的桥梁,好比,要打印第2行第2列的元素时,能够经过公式:\(1 * stride(0) + 1 * stride(1)\)来计算该元素在数组中的索引。
除了dtype、device、layout以外,dispatcher还能够用来调用legacy operator。好比说addmm这个operator,它的GPU实现就是经过dispatcher来跳转到legacy::cuda::_th_addmm。
到此,就完成了对PyTorch算子的学习。若是你要学习其余算子,能够先从aten/src/ATen/native目录的相关函数入手,从native_functions.yaml中找到dispatch目标函数,详情能够参考Figure 1。
更多精彩文章,欢迎扫码关注下方的公众号, 并访问个人简书博客:https://www.jianshu.com/u/c0fe8671254e
欢迎转发至朋友圈,工做号转载请后台留言申请受权~