《动手学深度学习》第十四天---多输入通道和多输出通道

以前咱们用到的输入和输出都是二维数组,但真实数据的维度常常更高。
好比彩色图像就有(RGB)三个通道。web

(一)多输入通道

下图数两个输入通道的卷积过程:
在这里插入图片描述过程大概就是在每一个通道上卷积核与输入进行互相关运算后,把获得的输出数组相加,这样获得的输出数组实际上是只有一个通道的。
通道数继续增长过程相似。
代码过程:算法

import d2lzh as d2l
from mxnet import nd

def corr2d_multi_in(X, K):
    # mxnet.ndarray.add_n(*args, **kwargs)实现的功能是;add_n(a1,a2,...,an)=a1+a2+...+an
    return nd.add_n(*[d2l.corr2d(x, k) for x, k in zip(X, K)])
    #首先利用以前的corr2d函数对每个通道的x和k进行互相关运算,而后利用add_n函数进行相加

在这里插入图片描述输出为:
在这里插入图片描述数组

(二)多输出通道

咱们知道上面的方法获得的输出通道只有一个,那么为了获得多输出通道,设卷积核输入通道数和输出通道数分别为ci和co,高和宽分别为kh和kw。若是但愿获得含多个通道的输出,咱们能够为每一个输出通道分别建立形状为ci×kh×kw的核数组。将它们在输出通道维上连结,卷积核的形状即co×ci×kh×kw。
在这个过程当中须要用到nd的stack函数,先补一下stack函数的功能:
在这里插入图片描述dom

def corr2d_multi_in_out(X, K):
    return nd.stack(*[corr2d_multi_in(X, k) for k in K])
    #  对于K的第一个维度co进行遍历,因此就是用每个维度为ci×kh×kw的卷积核与X卷积,这个过程与上面的多输入通道一致。
    #  获得了co个单通道输出数组后利用stack进行整合。
K = nd.stack(K, K + 1, K + 2)
#因为以前假设的2×2×2的形状,因此为了获得3通道的输出数组,把卷积核的第一维数变成3
#K+1是K的每一个元素加1
K.shape  #获得的K的形状如今应该是3×2×2×2
corr2d_multi_in_out(X, K)

在这里插入图片描述

(三)1×1卷积层

卷积层的长和宽均为1×1时称为1×1卷积层。假设将通道维看成特征维,将高和宽维度上的元素当成数据样本,那么1×1卷积层的做用与全链接层等价。通常用1×1卷积层来进行信道压缩,信道降维
怎么理解这个与全链接层等价呢?
在这里插入图片描述
从这张图上理解,输入有三个维度,输出有两个维度,因此卷积核为2×3×长宽。
浅蓝色的输出是由输入的浅蓝色方块与该维度的浅蓝色核相乘后相加获得的,形如
x1×w1+x2×w2+x3×w3。
从代码上理解:
先对X和K进行定义:svg

X = nd.random.uniform(shape=(3, 3, 3))
K = nd.random.uniform(shape=(2, 3, 1, 1))

X:
在这里插入图片描述
K:
在这里插入图片描述
再来看看计算过程:函数

def corr2d_multi_in_out_1x1(X, K):
    c_i, h, w = X.shape  #  c_i,h,w分别得到X的输入维度,长,宽
    c_o = K.shape[0]  #  c_o得到输出维度
    X = X.reshape((c_i, h * w))  #对X和K进行整形,便于dot运算,整形结果以下图
    K = K.reshape((c_o, c_i))
    Y = nd.dot(K, X)  # 全链接层的矩阵乘法
    return Y.reshape((c_o, h, w))

在这里插入图片描述对Y整形后:
在这里插入图片描述
在图中看到X和K整形后的结果,咱们来看一下dot运算和图解是否一致:
在dot运算中,
X中每一行第一个元素(黑色框)(其实就是未整形X的每一维的第一行第一个元素)与Y第一行相乘后相加获得的是Y中第一行第一个元素(其实也是Y的每一维的第一行第一个元素)
X中每一行第四个元素(红色框)(其实就是未整形X的每一维的第二行第一个元素)与Y第一行相乘后相加获得的是Y中第一行第四个元素(其实也是Y的每一维的第二行第一个元素)
X中每一行第七个元素(紫色框)(其实就是未整形X的每一维的第三行第一个元素)与Y第一行相乘后相加获得的是Y中第一行第七个元素(其实也是Y的每一维的第三行第一个元素)3d

能够看出这个算法和上面的图解以及全链接层的运算方法一致。code