1 大纲概述html
文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类。总共有如下系列:python
textCNN 模型github
charCNN 模型json
Bi-LSTM 模型app
RCNN 模型函数
全部代码均在textClassifier仓库中。
2 数据集
数据集为IMDB 电影影评,总共有三个数据文件,在/data/rawData目录下,包括unlabeledTrainData.tsv,labeledTrainData.tsv,testData.tsv。在进行文本分类时须要有标签的数据(labeledTrainData),数据预处理如文本分类实战(一)—— word2vec预训练词向量中同样,预处理后的文件为/data/preprocess/labeledTrain.csv。
3 BERT预训练模型
BERT 模型来源于论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。BERT模型是谷歌提出的基于双向Transformer构建的语言模型。BERT模型和ELMo有大不一样,在以前的预训练模型(包括word2vec,ELMo等)都会生成词向量,这种类别的预训练模型属于domain transfer。而近一两年提出的ULMFiT,GPT,BERT等都属于模型迁移。
BERT 模型是将预训练模型和下游任务模型结合在一块儿的,也就是说在作下游任务时仍然是用BERT模型,并且自然支持文本分类任务,在作文本分类任务时不须要对模型作修改。谷歌提供了下面七种预训练好的模型文件。
BERT模型在英文数据集上提供了两种大小的模型,Base和Large。Uncased是意味着输入的词都会转变成小写,cased是意味着输入的词会保存其大写(在命名实体识别等项目上须要)。Multilingual是支持多语言的,最后一个是中文预训练模型。
在这里咱们选择BERT-Base,Uncased。下载下来以后是一个zip文件,解压后有ckpt文件,一个模型参数的json文件,一个词汇表txt文件。
在应用BERT模型以前,咱们须要去github上下载开源代码,咱们能够直接clone下来,在这里有一个run_classifier.py文件,在作文本分类项目时,咱们须要修改这个文件,主要是添加咱们的数据预处理类。clone下来的项目结构以下:
在run_classifier.py文件中有一个基类DataProcessor类,其代码以下:
class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines
在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来咱们要定义本身的数据处理的类,咱们将咱们的类命名为IMDBProcessor
class IMDBProcessor(DataProcessor): """ IMDB data processor """ def _read_csv(self, data_dir, file_name): with tf.gfile.Open(data_dir + file_name, "r") as f: reader = csv.reader(f, delimiter=",", quotechar=None) lines = [] for line in reader: lines.append(line) return lines def get_train_examples(self, data_dir): lines = self._read_csv(data_dir, "trainData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "train-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_dev_examples(self, data_dir): lines = self._read_csv(data_dir, "devData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "dev-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_test_examples(self, data_dir): lines = self._read_csv(data_dir, "testData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "test-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_labels(self): return ["0", "1"]
在这里咱们没有直接用基类中的静态方法_read_tsv,由于咱们的csv文件是用逗号分隔的,所以就本身定义了一个_read_csv的方法,其他的方法就是读取训练集,验证集,测试集和标签。在这里标签就是一个列表,将咱们的类别标签放入就行。训练集,验证集和测试集都是返回一个InputExample对象的列表。InputExample是run_classifier.py中定义的一个类,代码以下:
class InputExample(object): """A single training/test example for simple sequence classification.""" def __init__(self, guid, text_a, text_b=None, label=None): """Constructs a InputExample. Args: guid: Unique id for the example. text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified. text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks. label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples. """ self.guid = guid self.text_a = text_a self.text_b = text_b self.label = label
在这里定义了text_a和text_b,说明是支持句子对的输入的,不过咱们这里作文本分类只有一个句子的输入,所以text_b能够不传参。
另外从上面咱们自定义的数据处理类中能够看出,训练集和验证集是保存在不一样文件中的,所以咱们须要将咱们以前预处理好的数据提早分割成训练集和验证集,并存放在同一个文件夹下面,文件的名称要和类中方法里的名称相同。
到这里以后咱们已经准备好了咱们的数据集,并定义好了数据处理类,此时咱们须要将咱们的数据处理类加入到run_classifier.py文件中的main函数下面的processors字典中,结果以下:
以后就能够直接执行run_classifier.py文件,执行脚本以下:
export BERT_BASE_DIR=../modelParams/uncased_L-12_H-768_A-12 export DATASET=../data/ python run_classifier.py \ --data_dir=$MY_DATASET \ --task_name=imdb \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --output_dir=../output/ \ --do_train=true \ --do_eval=true \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=200 \ --train_batch_size=16 \ --learning_rate=5e-5\ --num_train_epochs=2.0
在这里的task_name就是咱们定义的数据处理类的键,BERT模型较大,加载时须要较大的内存,若是出现内存溢出的问题,能够适当的下降batch_size的值。
目前迭代完以后的输出比较少,并且只有等迭代结束后才会有结果输出,不利于观察损失的变化,后续将修改输出。目前的输出结果:
测试集上的准确率达到了90.7% ,这个结果比Bi-LSTM + Attention(87.7%)的结果要好。
4 增长验证集输出的指标值
目前验证集上的输出指标值只有loss和accuracy,如上图所示,然而在分类时,咱们可能还须要看auc,recall,precision的值。增长几行代码就能够搞定:
在个人代码中743行这里有个metric_fn函数,以前这个函数下只有loss和accuracy的计算,咱们在这里加上auc,recall,precision的计算,而后加入到return的这个字典中就能够了。如今的输出结果:
5 关于BERT的问题
在run_classifier.py文件中,训练模型,验证模型都是用的tensorflow中的estimator接口,所以咱们没法实如今训练迭代100步就用验证集验证一次,在run_classifier.py文件中提供的方法是先运行完全部的epochs以后,再加载模型进行验证。训练模型时的代码:
在个人代码中948行这里,在这里咱们加入了几行代码,能够实现训练时输出loss,就是上面的:
tensors_to_log = {"train loss": "loss/Mean:0"} logging_hook = tf.train.LoggingTensorHook( tensors=tensors_to_log, every_n_iter=100)
这是咱们添加进去的,加入了一个hooks的参数,让训练的时候没迭代100步就输出一次loss。然而这样的意义并非很大。
下面的日志能够看到验证时是加载训练完的模型来进行验证的,见下图第一行:Restoring xxx
这种没法在训练时输出验证集上的结果,会致使咱们很难直观的看到损失函数的变化。就没法很方便的肯定模型是否收敛,这也是tensorflow中这些高级API的问题,高级封装虽然让书写代码更容易,但也让代码更死板。
bert的其余应用在NLP-Project中的pre_trained_model中,包括bert+bilstm+crf作命名实体识别,bert+cnn作文本分类。