Scaled Dot-Product Attention (Transformer)

Scaled Dot-Product Attention is a building block of the multi-head attention in the Transformer encoder.

Because Scaled Dot-Product Attention sits inside multi-head attention, its inputs q, k, v are usually reshaped to the following shape:

(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.
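A minimal sketch of this reshape, using hypothetical sizes (batch=2, seqLen=10, embedSize=512, n_head=8, so dim=64):

import torch

batch, seq_len, embed_size, n_head = 2, 10, 512, 8
dim = embed_size // n_head  # n_head * dim = embed_size

x = torch.randn(batch, seq_len, embed_size)               # (batch, seqLen, embedSize)
q = x.view(batch, seq_len, n_head, dim).transpose(1, 2)   # (batch, n_head, seqLen, dim)
print(q.shape)  # torch.Size([2, 8, 10, 64])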

From input to output, the tensor dimensions stay the same.

temperature is the scaling factor ("Scaled"), i.e. dim ** 0.5.
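Written as a formula (the standard scaled dot-product attention, which the code below follows), the raw dot products are divided by this temperature before the softmax:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{dim}}\right)V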

mask marks, for each sample in the batch, the positions where the sequence is padding; those positions are False. The mask therefore starts with shape (batchSize, seqLen) and, for the computation, is expanded to (batchSize, 1, 1, seqLen), as in the sketch below.
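A minimal sketch of how such a mask might be built, assuming a hypothetical pad id of 0 and a toy batch:

import torch

PAD = 0  # assumed padding token id
tokens = torch.tensor([[5, 7, 2, PAD, PAD],
                       [3, 9, 4, 6, PAD]])    # (batchSize, seqLen) = (2, 5)
mask = (tokens != PAD)                        # (2, 5), False at padded positions
mask = mask.unsqueeze(1).unsqueeze(2)         # (2, 1, 1, 5), broadcasts over heads and query positions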

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """
    Compute 'Scaled Dot-Product Attention'.
    """
    def forward(self, query, key, value, mask=None, dropout=None):  # (batch, n_head, seq_len, dim)
        # dot-product scores, scaled by the temperature sqrt(dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(query.size(-1))  # (batch, n_head, seq_len_q, seq_len_k)

        if mask is not None:
            # padded positions (mask == 0) get a large negative score so softmax assigns them ~0 weight
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        # weighted sum of values: (batch, n_head, seq_len_q, dim)
        return torch.matmul(p_attn, value), p_attn
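A quick usage sketch with random tensors (the shapes are the hypothetical ones used above):

import torch
import torch.nn as nn

batch, n_head, seq_len, dim = 2, 8, 10, 64
q = torch.randn(batch, n_head, seq_len, dim)
k = torch.randn(batch, n_head, seq_len, dim)
v = torch.randn(batch, n_head, seq_len, dim)
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)  # no padding in this toy example

attn = ScaledDotProductAttention()
out, p_attn = attn(q, k, v, mask=mask, dropout=nn.Dropout(0.1))
print(out.shape)     # torch.Size([2, 8, 10, 64])
print(p_attn.shape)  # torch.Size([2, 8, 10, 10])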

Note:

When Q, K, V all come from matrix transformations of the same vector, this is called self-attention;

when Q and K, V come from matrix transformations of different vectors, it is called soft-attention.
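A hedged sketch of that distinction, using hypothetical linear projections w_q, w_k, w_v:

import torch
import torch.nn as nn

embed_size = 512
w_q = nn.Linear(embed_size, embed_size)
w_k = nn.Linear(embed_size, embed_size)
w_v = nn.Linear(embed_size, embed_size)

x = torch.randn(2, 10, embed_size)  # one sequence
y = torch.randn(2, 7, embed_size)   # a different sequence

# self-attention: Q, K, V are all projections of the same vector x
q, k, v = w_q(x), w_k(x), w_v(x)

# the second case above: Q comes from y, while K and V come from x
q2, k2, v2 = w_q(y), w_k(x), w_v(x)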
