在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2以后引入的API,方便数据的操做。若是你还在使用老的Queue和Coordinator的方式,建议升级高版本的tensorflow而且使用Dataset API。python
本教程将从训练数据和推断数据两个方面,详解解析数据的具体处理过程,你将看到文本数据如何转化为模型所须要的实数,以及中间的张量的维度是怎么样的,batch_size和其余超参数又是如何做用的。git
先来看看训练数据的处理。训练数据的处理比推断数据的处理稍微复杂一些,弄懂了训练数据的处理过程,就能够很轻松地理解推断数据的处理。
训练数据的处理代码位于nmt/utils/iterator_utils.py文件内的get_iterator
函数。github
咱们先来看看这个函数所须要的参数是什么意思:网络
参数 | 解释 |
---|---|
src_dataset |
源数据集 |
tgt_dataset |
目标数据集 |
src_vocab_table |
源数据单词查找表,就是个单词和int类型数据的对应表 |
tgt_vocab_table |
目标数据单词查找表,就是个单词和int类型数据的对应表 |
batch_size |
批大小 |
sos |
句子开始标记 |
eos |
句子结尾标记 |
random_seed |
随机种子,用来打乱数据集的 |
num_buckets |
桶数量 |
src_max_len |
源数据最大长度 |
tgt_max_len |
目标数据最大长度 |
num_parallel_calls |
并发处理数据的并发数 |
output_buffer_size |
输出缓冲区大小 |
skip_count |
跳过数据行数 |
num_shards |
将数据集分片的数量,分布式训练中有用 |
shard_index |
数据集分片后的id |
reshuffle_each_iteration |
是否每次迭代都从新打乱顺序 |
上面的解释,若是有不清楚的,能够查看我以前一片介绍超参数的文章:
tensorflow_nmt的超参数详解并发
咱们首先搞清楚几个重要的参数是怎么来的。src_dataset
和tgt_dataset
是咱们的训练数据集,他们是逐行一一对应的。好比咱们有两个文件src_data.txt
和tgt_data.txt
分别对应训练数据的源数据和目标数据,那么它们的Dataset如何建立的呢?其实利用Dataset API很简单:app
src_dataset=tf.data.TextLineDataset('src_data.txt') tgt_dataset=tf.data.TextLineDataset('tgt_data.txt')
这就是上述函数中的两个参数src_dataset
和tgt_dataset
的由来。dom
src_vocab_table
和tgt_vocab_table
是什么呢?一样顾名思义,就是这两个分别表明源数据词典的查找表和目标数据词典的查找表,实际上查找表就是一个字符串到数字的映射关系。固然,若是咱们的源数据和目标数据使用的是同一个词典,那么这两个查找表的内容是如出一辙的。很容易想到,确定也有一种数字到字符串的映射表,这是确定的,由于神经网络的数据是数字,而咱们须要的目标数据是字符串,所以它们之间确定有一个转换的过程,这个时候,就须要咱们的reverse_vocab_table来做用了。分布式
咱们看看这两个表是怎么构建出来的呢?代码很简单,利用tensorflow库中定义的lookup_ops便可:函数
def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" src_vocab_table = lookup_ops.index_table_from_file( src_vocab_file, default_value=UNK_ID) if share_vocab: tgt_vocab_table = src_vocab_table else: tgt_vocab_table = lookup_ops.index_table_from_file( tgt_vocab_file, default_value=UNK_ID) return src_vocab_table, tgt_vocab_table
咱们能够发现,建立这两个表的过程,就是将词典中的每个词,对应一个数字,而后返回这些数字的集合,这就是所谓的词典查找表。效果上来讲,就是对词典中的每个词,从0开始递增的分配一个数字给这个词。fetch
那么到这里你有可能会有疑问,咱们词典中的词和咱们自定义的标记sos
等是否是有可能被映射为同一个整数而形成冲突?这个问题该如何解决?聪明如你,这个问题是存在的。那么咱们的项目是如何解决的呢?很简单,那就是将咱们自定义的标记当成词典的单词,而后加入到词典文件中,这样一来,lookup_ops
操做就把标记当成单词处理了,也就就解决了冲突!
具体的过程,本文后面会有一个例子,能够为您呈现具体过程。
若是咱们指定了share_vocab
参数,那么返回的源单词查找表和目标单词查找表是同样的。咱们还能够指定一个default_value,在这里是UNK_ID
,实际上就是0
。若是不指定,那么默认值为-1
。这就是查找表的建立过程。若是你想具体的知道其代码实现,能够跳转到tensorflow的C++核心部分查看代码(使用PyCharm或者相似的IDE)。
该函数处理训练数据的主要代码以下:
if not output_buffer_size: output_buffer_size = batch_size * 1000 src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: src_tgt_dataset = src_tgt_dataset.skip(skip_count) src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Filter zero length input sequences. src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls)