Transformer 模型的 PyTorch 实现

本文由罗周杨原创,转载请注明做者和出处。未经受权,不得用于商业用途。html

Google 2017年的论文 Attention is all you need 阐释了什么叫作大道至简!该论文提出了Transformer模型,彻底基于Attention mechanism,抛弃了传统的RNNCNNpython

咱们根据论文的结构图,一步一步使用 PyTorch 实现这个Transformer模型。git

Transformer架构

首先看一下transformer的结构图: github

transformer_architecture

解释一下这个结构图。首先,Transformer模型也是使用经典的encoer-decoder架构,由encoder和decoder两部分组成。网络

上图的左半边用Nx框出来的,就是咱们的encoder的一层。encoder一共有6层这样的结构。架构

上图的右半边用Nx框出来的,就是咱们的decoder的一层。decoder一共有6层这样的结构。app

输入序列通过word embeddingpositional encoding相加后,输入到encoder。函数

输出序列通过word embeddingpositional encoding相加后,输入到decoder。学习

最后,decoder输出的结果,通过一个线性层,而后计算softmax。ui

word embeddingpositional encoding我后面会解释。咱们首先详细地分析一下encoder和decoder的每一层是怎么样的。

Encoder

encoder由6层相同的层组成,每一层分别由两部分组成:

  • 第一部分是一个multi-head self-attention mechanism
  • 第二部分是一个position-wise feed-forward network,是一个全链接层

两个部分,都有一个 残差链接(residual connection),而后接着一个Layer Normalization

若是你是一个新手,你可能会问:

  • multi-head self-attention 是什么呢?
  • 参差结构是什么呢?
  • Layer Normalization又是什么?

这些问题咱们在后面会一一解答。

Decoder

和encoder相似,decoder由6个相同的层组成,每个层包括如下3个部分:

  • 第一个部分是multi-head self-attention mechanism
  • 第二部分是multi-head context-attention mechanism
  • 第三部分是一个position-wise feed-forward network

仍是和encoder相似,上面三个部分的每个部分,都有一个残差链接,后接一个Layer Normalization

可是,decoder出现了一个新的东西multi-head context-attention mechanism。这个东西其实也不复杂,理解了multi-head self-attention你就能够理解multi-head context-attention。这个咱们后面会讲解。

Attention机制

在讲清楚各类attention以前,咱们得先把attention机制说清楚。

通俗来讲,attention是指,对于某个时刻的输出y,它在输入x上各个部分的注意力。这个注意力实际上能够理解为权重

attention机制也能够分红不少种。Attention? Attention! 一问有一张比较全面的表格:

attention_mechanism
Figure 2. a summary table of several popular attention mechanisms.

上面第一种additive attention你可能听过。之前咱们的seq2seq模型里面,使用attention机制,这种**加性注意力(additive attention)**用的不少。Google的项目 tensorflow/nmt 里面使用的attention就是这种。

为何这种attention叫作additive attention呢?很简单,对于输入序列隐状态h_i和输出序列的隐状态s_t,它的处理方式很简单,直接合并,变成[s_t;h_i]

可是咱们的transformer模型使用的不是这种attention机制,使用的是另外一种,叫作乘性注意力(multiplicative attention)

那么这种乘性注意力机制是怎么样的呢?从上表中的公式也能够看出来:两个隐状态进行点积

Self-attention是什么?

到这里就能够解释什么是self-attention了。

上面咱们说attention机制的时候,都会说到两个隐状态,分别是h_is_t,前者是输入序列第i个位置产生的隐状态,后者是输出序列在第t个位置产生的隐状态。

所谓self-attention实际上就是,输出序列就是输入序列!所以,计算本身的attention得分,就叫作self-attention

Context-attention是什么?

知道了self-attention,那你确定猜到了context-attention是什么了:它是encoder和decoder之间的attention!因此,你也能够称之为encoder-decoder attention!

context-attention一词并非本人原创,有些文章或者代码会这样描述,我以为挺形象的,因此在此沿用这个称呼。其余文章可能会有其余名称,可是没关系,咱们抓住了重点便可,那就是两个不一样序列之间的attention,与self-attention相区别。

无论是self-attention仍是context-attention,它们计算attention分数的时候,能够选择不少方式,好比上面表中提到的:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那么咱们的Transformer模型,采用的是哪一种呢?答案是:scaled dot-product attention

Scaled dot-product attention是什么?

论文Attention is all you need里面对于attention机制的描述是这样的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

这句话描述得很清楚了。翻译过来就是:经过肯定Q和K之间的类似程度来选择V

用公式来描述更加清晰:

\text{Attention}(Q,K,V)=softmax(\frac{QK^T}{\sqrt d_k})V

scaled dot-product attentiondot-product attention惟一的区别就是,scaled dot-product attention有一个缩放因子\frac{1}{\sqrt d_k}

上面公式中的d_k表示的是K的维度,在论文里面,默认是64

那么为何须要加上这个缩放因子呢?论文里给出了解释:对于d_k很大的时候,点积获得的结果维度很大,使得结果处于softmax函数梯度很小的区域。

咱们知道,梯度很小的状况,这对反向传播不利。为了克服这个负面影响,除以一个缩放因子,能够必定程度上减缓这种状况。

为何是\frac{1}{\sqrt d_k}呢?论文没有进一步说明。我的以为你可使用其余缩放因子,看看模型效果有没有提高。

论文也提供了一张很清晰的结构图,供你们参考:

scaled_dot_product_attention_arch
Figure 3. Scaled dot-product attention architecture.

首先说明一下咱们的K、Q、V是什么:

  • 在encoder的self-attention中,Q、K、V都来自同一个地方(相等),他们是上一层encoder的输出。对于第一层encoder,它们就是word embedding和positional encoding相加获得的输入。
  • 在decoder的self-attention中,Q、K、V都来自于同一个地方(相等),它们是上一层decoder的输出。对于第一层decoder,它们就是word embedding和positional encoding相加获得的输入。可是对于decoder,咱们不但愿它能得到下一个time step(即未来的信息),所以咱们须要进行sequence masking
  • 在encoder-decoder attention中,Q来自于decoder的上一层的输出,K和V来自于encoder的输出,K和V是同样的。
  • Q、K、V三者的维度同样,即 d_q=d_k=d_v

上面scaled dot-product attention和decoder的self-attention都出现了masking这样一个东西。那么这个mask究竟是什么呢?这两处的mask操做是同样的吗?这个问题在后面会有详细解释。

Scaled dot-product attention的实现

我们先把scaled dot-product attention实现了吧。代码以下:

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention mechanism."""

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """前向传播. Args: q: Queries张量,形状为[B, L_q, D_q] k: Keys张量,形状为[B, L_k, D_k] v: Values张量,形状为[B, L_v, D_v],通常来讲就是k scale: 缩放因子,一个浮点标量 attn_mask: Masking张量,形状为[B, L_q, L_k] Returns: 上下文张量和attetention张量 """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
        	attention = attention * scale
        if attn_mask:
        	# 给须要mask的地方设置一个负无穷
        	attention = attention.masked_fill_(attn_mask, -np.inf)
		# 计算softmax
        attention = self.softmax(attention)
		# 添加dropout
        attention = self.dropout(attention)
		# 和V作点积
        context = torch.bmm(attention, v)
        return context, attention
复制代码

Multi-head attention又是什么呢?

理解了Scaled dot-product attention,Multi-head attention也很简单了。论文提到,他们发现将Q、K、V经过一个线性映射以后,分红 h 份,对每一份进行scaled dot-product attention效果更好。而后,把各个部分的结果合并起来,再次通过线性映射,获得最终的输出。这就是所谓的multi-head attention。上面的超参数 h 就是heads数量。论文默认是8

下面是multi-head attention的结构图:

multi-head attention_architecture
Figure 4: Multi-head attention architecture.

值得注意的是,上面所说的分红 h是在 d_k、d_q、d_v 维度上面进行切分的。所以,进入到scaled dot-product attention的 d_k 实际上等于未进入以前的 D_K/h

Multi-head attention容许模型加入不一样位置的表示子空间的信息。

Multi-head attention的公式以下:

\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_ 1,\dots,\text{head}_ h)W^O

其中,

\text{head}_ i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V)

论文里面,d_{model}=512h=8。因此在scaled dot-product attention里面的

d_q = d_k = d_v = d_{model}/h = 512/8 = 64

Multi-head attention的实现

相信你们已经理清楚了multi-head attention,那么咱们来实现它吧。代码以下:

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
		# multi-head attention以后须要作layer norm
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, key, value, query, attn_mask=None):
		# 残差链接
        residual = query

        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch_size = key.size(0)

        # linear projection
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split by heads
        key = key.view(batch_size * num_heads, -1, dim_per_head)
        value = value.view(batch_size * num_heads, -1, dim_per_head)
        query = query.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
          query, key, value, scale, attn_mask)

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)

        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_norm(residual + output)

        return output, attention

复制代码

上面的代码终于出现了Residual connectionLayer normalization。咱们如今来解释它们。

Residual connection是什么?

残差链接其实很简单!给你看一张示意图你就明白了:

residual_conn
Figure 5. Residual connection.

假设网络中某个层对输入x做用后的输出是F(x),那么增长residual connection以后,就变成了:

F(x)+x

这个+x操做就是一个shortcut

那么残差结构有什么好处呢?显而易见:由于增长了一项x,那么该层网络对x求偏导的时候,多了一个常数项1!因此在反向传播过程当中,梯度连乘,也不会形成梯度消失

因此,代码实现residual connection很很是简单:

def residual(sublayer_fn,x):
	return sublayer_fn(x)+x
复制代码

文章开始的transformer架构图中的Add & Norm中的Add也就是指的这个shortcut

至此,residual connection的问题理清楚了。更多关于残差网络的介绍能够看文末的参考文献。

Layer normalization是什么?

GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION一文对normalization有很好的解释:

Normalization有不少种,可是它们都有一个共同的目的,那就是把输入转化成均值为0方差为1的数据。咱们在把数据送入激活函数以前进行normalization(归一化),由于咱们不但愿输入数据落在激活函数的饱和区。

说到normalization,那就确定得提到Batch Normalization。BN在CNN等地方用得不少。

BN的主要思想就是:在每一层的每一批数据上进行归一化。

咱们可能会对输入数据进行归一化,可是通过该网络层的做用后,咱们的的数据已经再也不是归一化的了。随着这种状况的发展,数据的误差愈来愈大,个人反向传播须要考虑到这些大的误差,这就迫使咱们只能使用较小的学习率来防止梯度消失或者梯度爆炸。

BN的具体作法就是对每一小批数据,在批这个方向上作归一化。以下图所示:

batch_normalization
Figure 6. Batch normalization example.(From theneuralperspective.com)

能够看到,右半边求均值是沿着数据批量N的方向进行的

Batch normalization的计算公式以下:

BN(x_i)=\alpha\times\frac{x_i-u_B}{\sqrt{\sigma_B^2+\epsilon}}+\beta

具体的实现能够查看上图的连接文章。

说完Batch normalization,就该说说我们今天的主角Layer normalization

那么什么是Layer normalization呢?:它也是归一化数据的一种方式,不过LN是在每个样本上计算均值和方差,而不是BN那种在批方向计算均值和方差

下面是LN的示意图:

layer_normalization
Figure 7. Layer normalization example.

和上面的BN示意图一比较就能够看出两者的区别啦!

下面看一下LN的公式,也BN十分类似:

LN(x_i)=\alpha\times\frac{x_i-u_L}{\sqrt{\sigma_L^2+\epsilon}}+\beta

Layer normalization的实现

上述两个参数\alpha\beta都是可学习参数。下面咱们本身来实现Layer normalization(PyTorch已经实现啦!)。代码以下:

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """实现LayerNorm。其实PyTorch已经实现啦,见nn.LayerNorm。"""

    def __init__(self, features, epsilon=1e-6):
        """Init. Args: features: 就是模型的维度。论文默认512 epsilon: 一个很小的数,防止数值计算的除0错误 """
        super(LayerNorm, self).__init__()
        # alpha
        self.gamma = nn.Parameter(torch.ones(features))
        # beta
        self.beta = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x):
        """前向传播. Args: x: 输入序列张量,形状为[B, L, D] """
        # 根据公式进行归一化
        # 在X的最后一个维度求均值,最后一个维度就是模型的维度
        mean = x.mean(-1, keepdim=True)
        # 在X的最后一个维度求方差,最后一个维度就是模型的维度
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.epsilon) + self.beta

复制代码

顺便提一句,Layer normalization多用于RNN这种结构。

Mask是什么?

如今终于轮到讲解mask了!mask顾名思义就是掩码,在咱们这里的意思大概就是对某些值进行掩盖,使其不产生效果

须要说明的是,咱们的Transformer模型里面涉及两种mask。分别是padding masksequence mask。其中后者咱们已经在decoder的self-attention里面见过啦!

其中,padding mask在全部的scaled dot-product attention里面都须要用到,而sequence mask只有在decoder的self-attention里面用到。

因此,咱们以前ScaledDotProductAttentionforward方法里面的参数attn_mask在不一样的地方会有不一样的含义。这一点咱们会在后面说明。

Padding mask

什么是padding mask呢?回想一下,咱们的每一个批次输入序列长度是不同的!也就是说,咱们要对输入序列进行对齐!具体来讲,就是给在较短的序列后面填充0。由于这些填充的位置,实际上是没什么意义的,因此咱们的attention机制不该该把注意力放在这些位置上,因此咱们须要进行一些处理。

具体的作法是,把这些位置的值加上一个很是大的负数(能够是负无穷),这样的话,通过softmax,这些位置的几率就会接近0

而咱们的padding mask其实是一个张量,每一个值都是一个Boolen,值为False的地方就是咱们要进行处理的地方。

下面是实现:

def padding_mask(seq_k, seq_q):
	# seq_k和seq_q的形状都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  # shape [B, L_q, L_k]
    return pad_mask
复制代码

Sequence mask

文章前面也提到,sequence mask是为了使得decoder不能看见将来的信息。也就是对于一个序列,在time_step为t的时刻,咱们的解码输出应该只能依赖于t时刻以前的输出,而不能依赖t以后的输出。所以咱们须要想一个办法,把t以后的信息给隐藏起来。

那么具体怎么作呢?也很简单:产生一个上三角矩阵,上三角的值全为1,下三角的值权威0,对角线也是0。把这个矩阵做用在每个序列上,就能够达到咱们的目的啦。

具体的代码实现以下:

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask
复制代码

哈佛大学的文章The Annotated Transformer有一张效果图:

sequence_mask
Figure 8. Sequence mask.

值得注意的是,原本mask只须要二维的矩阵便可,可是考虑到咱们的输入序列都是批量的,因此咱们要把本来二维的矩阵扩张成3维的张量。上面的代码能够看出,咱们已经进行了处理。

回到本小结开始的问题,attn_mask参数有几种状况?分别是什么意思?

  • 对于decoder的self-attention,里面使用到的scaled dot-product attention,同时须要padding masksequence mask做为attn_mask,具体实现就是两个mask相加做为attn_mask。
  • 其余状况,attn_mask一概等于padding mask

至此,mask相关的问题解决了。

Positional encoding是什么?

好了,终于要解释位置编码了,那就是文字开始的结构图提到的Positional encoding

就目前而言,咱们的Transformer架构彷佛少了点什么东西。没错,就是它对序列的顺序没有约束!咱们知道序列的顺序是一个很重要的信息,若是缺失了这个信息,可能咱们的结果就是:全部词语都对了,可是没法组成有意义的语句!

为了解决这个问题。论文提出了Positional encoding。这是啥?一句话归纳就是:对序列中的词语出现的位置进行编码!若是对位置进行编码,那么咱们的模型就能够捕捉顺序信息!

那么具体怎么作呢?论文的实现颇有意思,使用正余弦函数。公式以下:

PE(pos,2i) = sin(pos/10000^{2i/d_{model}})
PE(pos,2i+1) = cos(pos/10000^{2i/d_{model}})

其中,pos是指词语在序列中的位置。能够看出,在偶数位置,使用正弦编码,在奇数位置,使用余弦编码

上面公式中的d_{model}是模型的维度,论文默认是512

这个编码公式的意思就是:给定词语的位置\text{pos},咱们能够把它编码成d_{model}维的向量!也就是说,位置编码的每个维度对应正弦曲线,波长构成了从2\pi10000*2\pi的等比序列。

上面的位置编码是绝对位置编码。可是词语的相对位置也很是重要。这就是论文为何要使用三角函数的缘由!

正弦函数可以表达相对位置信息。,主要数学依据是如下两个公式:

sin(\alpha+\beta) = sin\alpha cos\beta + cos\alpha sin\beta
cos(\alpha+\beta) = cos\alpha cos\beta - sin\alpha sin\beta

上面的公式说明,对于词汇之间的位置偏移kPE(pos+k)能够表示成PE(pos)PE(k)的组合形式,这就是表达相对位置的能力!

以上就是PE的全部秘密。说完了positional encoding,那么咱们还有一个与之处于同一地位的word embedding

Word embedding你们都很熟悉了,它是对序列中的词汇的编码,把每个词汇编码成d_{model}维的向量!看到没有,Postional encoding是对词汇的位置编码,word embedding是对词汇自己编码

因此,我更喜欢positional encoding的另一个名字Positional embedding

Positional encoding的实现

PE的实现也不难,按照论文的公式便可。代码以下:

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    
    def __init__(self, d_model, max_seq_len):
        """初始化。 Args: d_model: 一个标量。模型的维度,论文默认是512 max_seq_len: 一个标量。文本序列的最大长度 """
        super(PositionalEncoding, self).__init__()
        
        # 根据论文给的公式,构造出PE矩阵
        position_encoding = np.array([
          [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
          for pos in range(max_seq_len)])
        # 偶数列使用sin,奇数列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩阵的第一行,加上一行全是0的向量,表明这`PAD`的positional encoding
        # 在word embedding中也常常会加上`UNK`,表明位置单词的word embedding,二者十分相似
        # 那么为何须要这个额外的PAD的编码呢?很简单,由于文本序列的长度不一,咱们须要对齐,
        # 短的序列咱们使用0在结尾补全,咱们也须要这些补全位置的编码,也就是`PAD`对应的位置编码
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))
        
        # 嵌入操做,+1是由于增长了`PAD`这个补全位置的编码,
        # Word embedding中若是词典增长`UNK`,咱们也须要+1。看吧,二者十分类似
        self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding,
                                                     requires_grad=False)
    def forward(self, input_len):
        """神经网络的前向传播。 Args: input_len: 一个张量,形状为[BATCH_SIZE, 1]。每个张量的值表明这一批文本序列中对应的长度。 Returns: 返回这一批序列的位置编码,进行了对齐。 """
        
        # 找出这一批序列的最大长度
        max_len = torch.max(input_len)
        tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
        # 对每个序列的位置进行对齐,在原序列位置的后面补上0
        # 这里range从1开始也是由于要避开PAD(0)的位置
        input_pos = tensor(
          [list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len])
        return self.position_encoding(input_pos)
    
复制代码

Word embedding的实现

Word embedding应该是老生常谈了,它实际上就是一个二维浮点矩阵,里面的权重是可训练参数,咱们只须要把这个矩阵构建出来就完成了word embedding的工做。

因此,具体的实现很简单:

import torch.nn as nn


embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 得到输入的词嵌入编码
seq_embedding = seq_embedding(inputs)*np.sqrt(d_model)
复制代码

上面vocab_size就是词典的大小,embedding_size就是词嵌入的维度大小,论文里面就是等于d_{model}=512。因此word embedding矩阵就是一个vocab_size*embedding_size的二维张量。

若是你想获取更详细的关于word embedding的信息,能够看个人另一个文章word2vec的笔记和实现

Position-wise Feed-Forward network是什么?

这就是一个全链接网络,包含两个线性变换和一个非线性函数(实际上就是ReLU)。公式以下:

FFN(x)=max(0,xW_1+b_1)W_2+b_2

这个线性变换在不一样的位置都表现地同样,而且在不一样的层之间使用不一样的参数。

论文提到,这个公式还能够用两个核大小为1的一维卷积来解释,卷积的输入输出都是d_{model}=512,中间层的维度是d_{ff}=2048

实现以下:

import torch
import torch.nn as nn


class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output
复制代码

Transformer的实现

至此,全部的细节都已经解释完了。如今来完成咱们Transformer模型的代码。

首先,咱们须要实现6层的encoder和decoder。

encoder代码实现以下:

import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
	"""Encoder的一层。"""

    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2018, dropout=0.0):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, inputs, attn_mask=None):

        # self attention
        context, attention = self.attention(inputs, inputs, inputs, padding_mask)

        # feed forward network
        output = self.feed_forward(context)

        return output, attention


class Encoder(nn.Module):
	"""多层EncoderLayer组成Encoder。"""

    def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Encoder, self).__init__()

        self.encoder_layers = nn.ModuleList(
          [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_mask = padding_mask(inputs, inputs)

        attentions = []
        for encoder in self.encoder_layers:
            output, attention = encoder(output, self_attention_mask)
            attentions.append(attention)

        return output, attentions

复制代码

经过文章前面的分析,代码不须要更多解释了。一样的,咱们的decoder代码以下:

import torch
import torch.nn as nn


class DecoderLayer(nn.Module):

    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, dec_inputs, enc_outputs, self_attn_mask=None, context_attn_mask=None):
        # self attention, all inputs are decoder inputs
        dec_output, self_attention = self.attention(
          dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        # context attention
        # query is decoder's outputs, key and value are encoder's inputs
        dec_output, context_attention = self.attention(
          enc_outputs, enc_outputs, dec_output, context_attn_mask)

        # decoder's output, or context
        dec_output = self.feed_forward(dec_output)

        return dec_output, self_attention, context_attention


class Decoder(nn.Module):

    def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Decoder, self).__init__()

        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
          [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

        self_attentions = []
        context_attentions = []
        for decoder in self.decoder_layers:
            output, self_attn, context_attn = decoder(
            output, enc_output, self_attn_mask, context_attn_mask)
            self_attentions.append(self_attn)
            context_attentions.append(context_attn)

        return output, self_attentions, context_attentions
复制代码

最后,咱们把encoder和decoder组成Transformer模型!

代码以下:

import torch
import torch.nn as nn


class Transformer(nn.Module):

    def __init__(self, src_vocab_size, src_max_len, tgt_vocab_size, tgt_max_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.2):
        super(Transformer, self).__init__()

        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)
        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)

        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, src_seq, src_len, tgt_seq, tgt_len):
        context_attn_mask = padding_mask(tgt_seq, src_seq)

        output, enc_self_attn = self.encoder(src_seq, src_len)

        output, dec_self_attn, ctx_attn = self.decoder(
          tgt_seq, tgt_len, output, context_attn_mask)

        output = self.linear(output)
        output = self.softmax(output)

        return output, enc_self_attn, dec_self_attn, ctx_attn

复制代码

至此,Transformer模型已经实现了!

参考文章

1.为何ResNet和DenseNet能够这么深?一文详解残差块为什么有助于解决梯度弥散问题
2.GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION
3.The Annotated Transformer
4.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part I
5.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part II
6.Attention?Attention!

参考代码

1.jadore801120/attention-is-all-you-need-pytorch
2.JayParks/transformer

联系我

相关文章
相关标签/搜索