基于BERT的超长文本分类模型
0.Abstract
本文实现了一个基于BERT+LSTM超长文本分类的模型, 评估方法使用准确率和F1 Score.
项目代码github地址: https://github.com/neesetifa/bert_classification
git
1.任务介绍
用BERT作文本分类是一个比较常见的项目.
可是众所周知BERT对于文本输入长度有限制. 对于超长文本的处理, 最简单暴力无脑高效的办法是直接截断, 就取开头这部分送入BERT. 可是也请别看不起这种作法, 每每最简单,最Naive的方法效果反而比一顿操做猛如虎 复杂模型来得好.
github
这里多提一句为何. 一般长文本的文章结构都比较明确, 文章前面一两段基本都是对于后面的概述. 因此等于做者已经帮你提取了文章大意, 因此直接取前面一部分理论上来讲是有意义的.
固然也有最新研究代表取文章中间部分效果也很不错. 在此不展开.
网络
本文实现的是一种基于HIERARCHICAL(级联)思想的作法, 把文本切成多片处理. 该方法来自于这篇论文 <Hierarchical Transformers for Long Document Classification>.
文中提到这么作还能下降self-attention计算的时间复杂度.
假设原句子长为n, 每一个分段的长度是k. 咱们知道最原始的BERT计算时间复杂度是O(n2), 做者认为,这么作能够把时间复杂度下降到O(nk). 由于咱们把n分数据分割成k小份, 那么咱们一共要作n/k次, 每次咱们的时间复杂度是k2, 即O(n/k * k2) = O(nk)
app
此次咱们测试该模型在两种语言上的效果. 分别是中文数据集和英语数据集.
中文数据集依旧是咱们的老朋友ChineseNLPCorps提供的不一样类别商品的评论.
中文数据集传送门
英语数据集来源于Kaggle比赛, 用户对于不一样金融产品的评论.
英语数据集传送门
因为两种数据集训练预测上没有什么本质区别, 下文会用英语数据集来演示.
学习
本项目使用的评估方法是准确率和F1 Score. 很是常见的分类问题评价标准.优化
此项目中直接取了数据集里一小部分做为测试集.spa
2.数据初步处理
数据集里有55W条数据,18个features.
咱们须要的部分是product(即商品类别)以及consumer complaint narrative.
观察数据集,咱们发现用户评论是有NaN值的. 并且本次实验目的是作超长文本分类. 咱们选取非NaN值,而且是长度大于250的评论.
筛选完后咱们保留大约17k条左右数据
3.Baseline模型
咱们先来看一下什么都不作, 直接用BERT进行finetune能达到什么样的效果. 咱们以此做为实验的baseline.
本次预训练模型使用google官方的BERT-base-cased英语预训练模型(固然用uncased应该也不要紧, 我没有测试)
fine-tune部分很简单, 直接提取[CLS] token后过线性层, 是比较常规的套路. 损失函数使用cross entropy loss.
文本送入的最大长度定为250. 即前文里提到的"直接截取文本前面部分". 这次实验里咱们尝试比较HIERARCHICAL方法能比直接截取提升多少.
如图, 准确率达到了88%. 训练数据不过10k的数量级, 对于深度学习来讲是很是少的. 这里不得不感叹下BERT做为预训练模型在小样本数据上的实力很是强劲.
4. 数据进一步处理
接下来咱们进入提升部分. 首先对数据进一步处理.
HIERARCHICAL思想本质是对数据进行有重叠(overlap)的分割. 这样分割后的每句句子之间仍然保留了必定的关联信息.
众所周知,BERT输入的最大长度限制为512, 其中还须要包括[CLS]和[SEP]. 那么实际可用的长度仅为510. 可是别忘了, 每一个单词tokenizer以后也有可能被分红好几部分. 因此实际可输入的句子长度远不足510.
本次实验里咱们设置分割的长度为200, overlap长度为50. 若是实际上线生产确有大量超过500长度的文本, 只需将分割和overlap长度设置更长便可.
def get_split_text(text, split_len=250, overlap_len=50): split_text=[] for w in range(len(text)//split_len): if w == 0: #第一次,直接分割长度放进去 text_piece = text[:split_len] else: # 不然, 按照(分割长度-overlap)日后走 window = split_len - overlap_len text_piece = [w * window: w * window + split_len] split_text.append(text_piece) return split_text
分割完后长这样
随后咱们将这些分割的句子分离成单独的一条数据. 并为他们加上label.
对比原文本能够发现, index 1~ index4来源于同一句句子. 它被分割成了4份而且每份都拥有原文本的label.
4.最终模型
最终模型由两个部分构成, 第一部分是和baseline里如出一辙的, fine-tune后的BERT. 第二部分是由LSTM+FC层组成的混合模型.
即实际上, BERT只是用来提取出句子的表示, 而真正在作分类的是LSTM + FC部分(更准确来讲是FC部分, 由于LSTM模型部分仍然在作进一步的特征提取工做)
这里稍微提一句,这样作法我我的认为相似于广告推荐系统里GBDT+LR的组合. 采用一个稍微复杂的模型去作特征提取, 而后用一个相对简单的模型去预测.
首先,咱们把分割好后的文本送入BERT进行训练. 这边我跑了5个epoch, 显卡仍然是Tesla K80, 每一个epoch大约须要23分钟左右.
接着, 咱们提取出这些文本的句子表示.
方便起见, 咱们这里仍然用[CLS] token做为句子表示. 固然也能够用sequence_output(在我上一个项目FAQ问答的最后结论中, 使用sequence_output的确能比pooled_output效果更好一点)
咱们得到的是这样一组数据:
句子1_a的embedding, label
句子1_b的embedding, label
句子1_c的embedding, label
句子2_a的embedding, label
句子2_b的embedding, label
句子3_a的embedding, label
…
随后咱们把这些embedding拼回起来, 变成了
[句子1_a的embedding,句子1_b的embedding, 句子1_c的embedding], label
[句子2_a的embedding, 句子2_b的embedding], label
[句子3_a的embedding, 句子3_b的embedding], label
这部分数据将做为LSTM部分的输入.
这一步,咱们将上一步获得的embedding直接送入LSTM网络训练.
回想一下, 咱们平时用LSTM作, 是否是把句子过了embedding层以后再送入LSTM的? 这里咱们直接跳过embedding层, 由于咱们的数据自己就是embedding
因为分割后的embedding都不会太长, 咱们直接使用LSTM最后一个time step的输出(固然这里也有个尝试点, 若是提取出LSTM每一个time step的输出效果是否是会更好?)
LSTM以后会过一个激活函数, 接一个FC层, FC层和label用cross entropy loss进行优化.
因为合并后的数据量比较小, 我跑了10个epoch, 每次都很快.
(左边loss, 右边accuracy)
最终效果竟然提升到了94%!! 说实话这个提高量远高于论文. 可能和数据自己好也有关系. 可是咱们能够认为, 比起直接截取文本开头一段, 采用HIERARCHICAL方式不只克服了BERT长度限制的缺点, 也极大提高了对于超长文本的分类效果.
下面是在中文数据集上模型的baseline效果和提高后的效果.
(待跑)
因此我认为, 采用HIERARCHICAL方法, 提高/解决了BERT两方面的缺点:
1.下降了BERT里self-attention部分计算的时间复杂度. 就如开头所说, 时间复杂度从O(n2)下降到O(nk). 这个状况尤为适用于长度在500之内长度的文本.
2.克服了BERT对于输入文本长度有限的缺点. 对于tokenize以后长度超过510的文本, 也能够用此方式对准确率进行再提高, 其实际效果优于直接截断文本.
5. 进一步拓展: BERT + Transformer
原论文里还提到了使用Transformer代替LSTM做为预测部分. 这一节咱们用Transformer来试一下. 咱们先来分析一下使用Transformer结构后的时间复杂度. 显然它的时间复杂度和LSTM不同(LSTM复杂度咱们能够认为是线性的, 即O(n/k)~O(n).) 首先在BERT部分, 时间复杂度不变, 依旧为为O(n/k * k2) = O(nk). 进入到Transformer后,每一个sequence长度为n/k, 因此时间复杂度为O(n/k * n/k)=O(n2/k2). 那么整体时间复杂度为 O(nk) + O(n2/k2) ~ O(n2/k2). 相比于LSTM的O(nk), 这个O(n2/k2)复杂度是有至关的上升的. 可是咱们考虑到 n/k << n, 即n/k的量级远小于n, 因此仍是在可接受的范围. (本小节未完…)