因为 Transformer 注意力机制对内存的需求是输入图像的二次方,因此这一方向还存在一些挑战。
近日,LambdaNetworks 的出现提供了一种解决此问题的方法,人们能够无需创建昂贵的注意力图便可捕捉长距离交互。这一方法在 ImageNet 上达到了新的业界最佳水平(state-of-the-art)。
git
论文连接:https://openreview.net/pdf?id=xTJEN-ggl1bgithub
GitHub连接:https://github.com/lucidrains/lambda-networks
对长程交互进行建模在机器学习中相当重要。注意力已成为捕获长程交互的一种经常使用范式。可是,自注意力二次方式的内存占用已经阻碍了其对长序列或多维输入(例如包含数万个像素的图像)的适用性。例如,将单个多头注意力层应用于一批 256 个64x64 (8 头)输入图像须要32GB的内存,这在实践中是不容许的。
框架
该研究提出了一种名为「lambda」的层,这些层提供了一种捕获输入和一组结构化上下文元素之间长程交互的通用框架。
lambda 层将可用上下文转换为单个线性函数(lambdas)。这些函数直接单独应用于每一个输入。研究者认为,lambda 层能够做为注意力机制的天然替代。注意力定义了输入元素和上下文元素之间的类似性核,而 lambda 层将上下文信息汇总为固定大小的线性函数,从而避免了对内存消耗大的注意力图的需求。这种对好比图1所示。
机器学习