Sentence-BERT详解

简述

BERT和RoBERTa在文本语义类似度(Semantic Textual Similarity)等句子对的回归任务上,已经达到了SOTA的结果。可是,它们都须要把两个句子同时送入网络,这样会致使巨大的计算开销:从10000个句子中找出最类似的句子对,大概须要5000万( C 10000 2 = 49 , 995 , 000 C_{10000}^2=49,995,000 )个推理计算,在V100GPU上耗时约65个小时。这种结构使得BERT不适合语义类似度搜索,一样也不适合无监督任务,例如聚类php

解决聚类和语义搜索的一种常见方法是将每一个句子映射到一个向量空间,使得语义类似的句子很接近。一般得到句子向量的方法有两种:html

  1. 计算全部Token输出向量的平均值
  2. 使用[CLS]位置输出的向量

然而,UKP的研究员实验发现,在文本类似度(STS)任务上,使用上述两种方法获得的效果却并很差,即便是Glove向量也明显优于朴素的BERT句子embeddings(见下图前三行)git

Sentence-BERT(SBERT)的做者对预训练的BERT进行修改:使用**Siamese and Triplet Network(孪生网络和三胞胎网络)**生成具备语义的句子Embedding向量。语义相近的句子,其Embedding向量距离就比较近,从而可使用余弦类似度、曼哈顿距离、欧氏距离等找出语义类似的句子。SBERT在保证准确性的同时,可将上述提到BERT/RoBERTa的65小时下降到5秒(计算余弦类似度大概0.01秒)。这样SBERT能够完成某些新的特定任务,好比聚类、基于语义的信息检索等github

模型介绍

Pooling策略

SBERT在BERT/RoBERTa的输出结果上增长了一个Pooling操做,从而生成一个固定维度的句子Embedding。实验中采起了三种Pooling策略作对比:markdown

  1. CLS:直接用CLS位置的输出向量做为整个句子向量
  2. MEAN:计算全部Token输出向量的平均值做为整个句子向量
  3. MAX:取出全部Token输出向量各个维度的最大值做为整个句子向量

三种策略的实验对比效果以下网络

由结果可见,MEAN的效果是最好的,因此后面实验默认采用的也是MEAN策略app

模型结构

为了可以fine-tune BERT/RoBERTa,文章采用了孪生网络和三胞胎网络来更新参数,以达到生成的句子向量更具语义信息。该网络结构取决于具体的训练数据,文中实验了下面几种机构和目标函数函数

Classification Objective Function

针对分类问题,做者将向量 u , v , u v u,v,|u-v| 三个向量拼接在一块儿,而后乘以一个权重参数 W t R 3 n × k W_t\in \mathbb{R}^{3n\times k} ,其中 n n 表示向量的维度, k k 表示label的数量oop

o = s o f t m a x ( W t [ u ; v ; u v ] ) o = softmax(W_t[u;v;|u-v|])

损失函数为CrossEntropyLoss优化

注:原文公式为 s o f t m a x ( W t ( u , v , u v ) ) softmax(W_t(u,v,|u-v|)) ,我我的比较喜欢用 [ ; ; ] [;;] 表示向量拼接的意思

Regression Objective Function

两个句子embedding向量 u , v u,v 的余弦类似度计算结构以下所示,损失函数为MAE(mean squared error)

Triplet Objective Function

更多关于Triplet Network的内容能够看个人这篇Siamese Network & Triplet NetWork。给定一个主句 a a ,一个正面句子 p p 和一个负面句子 n n ,三元组损失调整网络,使得 a a p p 之间的距离尽量小, a a n n 之间的距离尽量大。数学上,咱们指望最小化如下损失函数:

m a x ( s a s p s a s n + ϵ , 0 ) max(||s_a-s_p||-||s_a-s_n||+\epsilon, 0)

其中, s x s_x 表示句子 x x 的embedding, ||·|| 表示距离,边缘参数 ϵ \epsilon 表示 s a s_a s p s_p 的距离至少应比 s a s_a s n s_n 的距离近 ϵ \epsilon 。在实验中,使用欧式距离做为距离度量, ϵ \epsilon 设置为1

模型训练细节

做者训练时结合了SNLI(Stanford Natural Language Inference)和Multi-Genre NLI两种数据集。SNLI有570,000我的工标注的句子对,标签分别为矛盾,蕴含(eintailment),中立三种;MultiNLI是SNLI的升级版,格式和标签都同样,有430,000个句子对,主要是一系列口语和书面语文本

蕴含关系描述的是两个文本之间的推理关系,其中一个文本做为前提(Premise),另外一个文本做为假设(Hypothesis),若是根据前提可以推理得出假设,那么就说前提蕴含假设。参考样例以下:

Sentence A (Premise) Sentence B (Hypothesis) Label
A soccer game with multiple males playing. Some men are playing a sport. entailment
An older and younger man smiling. Two men are smiling and laughing at the cats playing on the floor. neutral
A man inspects the uniform of a figure in some East Asian country. The man is sleeping. contradiction

实验时,做者使用类别为3的softmax分类目标函数对SBERT进行fine-tune,batch_size=16,Adam优化器,learning_rate=2e-5

消融研究

为了对SBERT的不一样方面进行消融研究,以便更好地了解它们的相对重要性,咱们在SNLI和Multi-NLI数据集上构建了分类模型,在STS benchmark数据集上构建了回归模型。在pooling策略上,对比了MEAN、MAX、CLS三种策略;在分类目标函数中,对比了不一样的向量组合方式。结果以下

结果代表,Pooling策略影响较小,向量组合策略影响较大,而且 [ u ; v ; u v ] [u;v;|u-v|] 效果最好

Reference

相关文章
相关标签/搜索