基于BERT的超长文本分类模型

基于BERT的超长文本分类模型

 

0.Abstract

本文实现了一个基于BERT+LSTM超长文本分类的模型, 评估方法使用准确率和F1 Score.
项目代码github地址: https://github.com/neesetifa/bert_classification
git

1.任务介绍

用BERT作文本分类是一个比较常见的项目.
可是众所周知BERT对于文本输入长度有限制. 对于超长文本的处理, 最简单暴力无脑高效的办法是直接截断, 就取开头这部分送入BERT. 可是也请别看不起这种作法, 每每最简单,最Naive的方法效果反而比一顿操做猛如虎 复杂模型来得好.
github

这里多提一句为何. 一般长文本的文章结构都比较明确, 文章前面一两段基本都是对于后面的概述. 因此等于做者已经帮你提取了文章大意, 因此直接取前面一部分理论上来讲是有意义的.
固然也有最新研究代表取文章中间部分效果也很不错. 在此不展开.
网络

本文实现的是一种基于HIERARCHICAL(级联)思想的作法, 把文本切成多片处理. 该方法来自于这篇论文 <Hierarchical Transformers for Long Document Classification>.
文中提到这么作还能下降self-attention计算的时间复杂度.
假设原句子长为n, 每一个分段的长度是k. 咱们知道最原始的BERT计算时间复杂度是O(n2), 做者认为,这么作能够把时间复杂度下降到O(nk). 由于咱们把n分数据分割成k小份, 那么咱们一共要作n/k次, 每次咱们的时间复杂度是k2, 即O(n/k * k2) = O(nk)

app

数据集函数

此次咱们测试该模型在两种语言上的效果. 分别是中文数据集和英语数据集.
中文数据集依旧是咱们的老朋友ChineseNLPCorps提供的不一样类别商品的评论.
中文数据集传送门
英语数据集来源于Kaggle比赛, 用户对于不一样金融产品的评论.
英语数据集传送门
因为两种数据集训练预测上没有什么本质区别, 下文会用英语数据集来演示.




学习

评估方法测试

本项目使用的评估方法是准确率和F1 Score. 很是常见的分类问题评价标准.优化

测试集google

此项目中直接取了数据集里一小部分做为测试集.spa

2.数据初步处理

数据集里有55W条数据,18个features.
在这里插入图片描述
咱们须要的部分是product(即商品类别)以及consumer complaint narrative.
在这里插入图片描述
观察数据集,咱们发现用户评论是有NaN值的. 并且本次实验目的是作超长文本分类. 咱们选取非NaN值,而且是长度大于250的评论.



在这里插入图片描述
筛选完后咱们保留大约17k条左右数据
在这里插入图片描述

3.Baseline模型

咱们先来看一下什么都不作, 直接用BERT进行finetune能达到什么样的效果. 咱们以此做为实验的baseline.
本次预训练模型使用google官方的BERT-base-cased英语预训练模型(固然用uncased应该也不要紧, 我没有测试)
fine-tune部分很简单, 直接提取[CLS] token后过线性层, 是比较常规的套路. 损失函数使用cross entropy loss.
文本送入的最大长度定为250. 即前文里提到的"直接截取文本前面部分". 这次实验里咱们尝试比较HIERARCHICAL方法能比直接截取提升多少.
在这里插入图片描述
如图, 准确率达到了88%. 训练数据不过10k的数量级, 对于深度学习来讲是很是少的. 这里不得不感叹下BERT做为预训练模型在小样本数据上的实力很是强劲.




4. 数据进一步处理

接下来咱们进入提升部分. 首先对数据进一步处理.

分割文本

HIERARCHICAL思想本质是对数据进行有重叠(overlap)的分割. 这样分割后的每句句子之间仍然保留了必定的关联信息.

众所周知,BERT输入的最大长度限制为512, 其中还须要包括[CLS]和[SEP]. 那么实际可用的长度仅为510. 可是别忘了, 每一个单词tokenizer以后也有可能被分红好几部分. 因此实际可输入的句子长度远不足510.
本次实验里咱们设置分割的长度为200, overlap长度为50. 若是实际上线生产确有大量超过500长度的文本, 只需将分割和overlap长度设置更长便可.

def get_split_text(text, split_len=250, overlap_len=50):
	split_text=[]
	for w in range(len(text)//split_len):
  		if w == 0:   #第一次,直接分割长度放进去
    		text_piece = text[:split_len]
  		else:      # 不然, 按照(分割长度-overlap)日后走
  			window = split_len - overlap_len
    		text_piece = [w * window: w * window + split_len]
		split_text.append(text_piece)
	return split_text

分割完后长这样
在这里插入图片描述
随后咱们将这些分割的句子分离成单独的一条数据. 并为他们加上label.
在这里插入图片描述
对比原文本能够发现, index 1~ index4来源于同一句句子. 它被分割成了4份而且每份都拥有原文本的label.



4.最终模型

最终模型由两个部分构成, 第一部分是和baseline里如出一辙的, fine-tune后的BERT. 第二部分是由LSTM+FC层组成的混合模型.
即实际上, BERT只是用来提取出句子的表示, 而真正在作分类的是LSTM + FC部分(更准确来讲是FC部分, 由于LSTM模型部分仍然在作进一步的特征提取工做)
这里稍微提一句,这样作法我我的认为相似于广告推荐系统里GBDT+LR的组合. 采用一个稍微复杂的模型去作特征提取, 而后用一个相对简单的模型去预测.

第一部分: BERT

首先,咱们把分割好后的文本送入BERT进行训练. 这边我跑了5个epoch, 显卡仍然是Tesla K80, 每一个epoch大约须要23分钟左右.
在这里插入图片描述
接着, 咱们提取出这些文本的句子表示.
方便起见, 咱们这里仍然用[CLS] token做为句子表示. 固然也能够用sequence_output(在我上一个项目FAQ问答的最后结论中, 使用sequence_output的确能比pooled_output效果更好一点)
咱们得到的是这样一组数据:



句子1_a的embedding, label
句子1_b的embedding, label
句子1_c的embedding, label
句子2_a的embedding, label
句子2_b的embedding, label
句子3_a的embedding, label






随后咱们把这些embedding拼回起来, 变成了

[句子1_a的embedding,句子1_b的embedding, 句子1_c的embedding], label
[句子2_a的embedding, 句子2_b的embedding], label
[句子3_a的embedding, 句子3_b的embedding], label

这部分数据将做为LSTM部分的输入.

第二部分: LSTM + FC

这一步,咱们将上一步获得的embedding直接送入LSTM网络训练.

回想一下, 咱们平时用LSTM作, 是否是把句子过了embedding层以后再送入LSTM的? 这里咱们直接跳过embedding层, 由于咱们的数据自己就是embedding

因为分割后的embedding都不会太长, 咱们直接使用LSTM最后一个time step的输出(固然这里也有个尝试点, 若是提取出LSTM每一个time step的输出效果是否是会更好?)
LSTM以后会过一个激活函数, 接一个FC层, FC层和label用cross entropy loss进行优化.
因为合并后的数据量比较小, 我跑了10个epoch, 每次都很快.
在这里插入图片描述


最终效果和一些小节

(左边loss, 右边accuracy)
在这里插入图片描述
最终效果竟然提升到了94%!! 说实话这个提高量远高于论文. 可能和数据自己好也有关系. 可是咱们能够认为, 比起直接截取文本开头一段, 采用HIERARCHICAL方式不只克服了BERT长度限制的缺点, 也极大提高了对于超长文本的分类效果.

下面是在中文数据集上模型的baseline效果和提高后的效果.
(待跑)

因此我认为, 采用HIERARCHICAL方法, 提高/解决了BERT两方面的缺点:
1.下降了BERT里self-attention部分计算的时间复杂度. 就如开头所说, 时间复杂度从O(n2)下降到O(nk). 这个状况尤为适用于长度在500之内长度的文本.
2.克服了BERT对于输入文本长度有限的缺点. 对于tokenize以后长度超过510的文本, 也能够用此方式对准确率进行再提高, 其实际效果优于直接截断文本.

5. 进一步拓展: BERT + Transformer

原论文里还提到了使用Transformer代替LSTM做为预测部分. 这一节咱们用Transformer来试一下. 咱们先来分析一下使用Transformer结构后的时间复杂度. 显然它的时间复杂度和LSTM不同(LSTM复杂度咱们能够认为是线性的, 即O(n/k)~O(n).) 首先在BERT部分, 时间复杂度不变, 依旧为为O(n/k * k2) = O(nk). 进入到Transformer后,每一个sequence长度为n/k, 因此时间复杂度为O(n/k * n/k)=O(n2/k2). 那么整体时间复杂度为 O(nk) + O(n2/k2) ~ O(n2/k2). 相比于LSTM的O(nk), 这个O(n2/k2)复杂度是有至关的上升的. 可是咱们考虑到 n/k << n, 即n/k的量级远小于n, 因此仍是在可接受的范围. (本小节未完…)

相关文章
相关标签/搜索