关于transformer-xl中rel-shift实现的解读

  方法 抽象地看,我们要做的事情就是,给定一个矩阵,每行都进行左移,而移动的个数随行数递增而递减。 我目前想到的一种方法是使用gather,将想要的index提前定好,然后使用Pytorch的gather就能够实现。 而transformer-xl实现了另一种更好的方法:_rel_shift。 def _rel_shift(self, x, zero_triu=False): # x:
相关文章
相关标签/搜索