tensorflow LSTM+CTC使用详解

  最近用tensorflow写了个OCR的程序,在实现的过程当中,发现本身仍是跳了很多坑,在这里作一个记录,便于之后回忆。主要的内容有lstm+ctc具体的输入输出,以及TF中的CTC和百度开源的warpCTC在具体使用中的区别。python

正文

输入输出

由于我最后要最小化的目标函数就是ctc_loss,因此下面就从如何构造输入输出提及。git

tf.nn.ctc_loss

先从TF自带的tf.nn.ctc_loss提及,官方给的定义以下,所以咱们须要作的就是将图片的label(须要OCR出的结果),图片,以及图片的长度转换为label,input,和sequence_length。github

ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
time_major=True
)
input: 输入(训练)数据,是一个三维float型的数据结构 [max_time_step , batch_size , num_classes],当修改time_major = False时, [batch_size,max_time_step,num_classes]
整体的数据流:
image_batch
-> [batch_size,max_time_step,num_features]->lstm
-> [batch_size,max_time_step,cell.output_size]->reshape
-> [batch_size*max_time_step,num_hidden]->affine projection A*W+b
-> [batch_size*max_time_step,num_classes]->reshape
-> [batch_size,max_time_step,num_classes]->transpose
-> [max_time_step,batch_size,num_classes]
下面详细解释一下,
假如一张图片有以下shape:[60,160,3],咱们若是读取灰度图则shape=[60,160],此时,咱们将其一列做为feature,那么共有60个features,160个time_step,这时假设一个batch为64,那么咱们此时得到到了一个 [batch_size,max_time_step,num_features] = [64,160,60]的训练数据。
而后将该训练数据送入 构建的lstm网络中,(须要注意的是 dynamic_rnn的输入数据在一个batch内的长度是固定的,可是不一样batch之间能够不一样,咱们须要给他一个 sequence_length(长度为batch_size的向量)来记录本次batch数据的长度,对于OCR这个问题,sequence_length就是长度为64,而值为160的一维向量)
获得形如 [batch_size,max_time_step,cell.output_size]的输出,其中cell.output_size == num_hidden。
下面咱们须要作一个线性变换将其送入ctc_loos中进行计算,lstm中不一样time_step之间共享权值,因此咱们只需定义 W的结构为 [num_hidden,num_classes]b的结构为[num_classes]。而 tf.matmul操做中,两个矩阵相乘阶数应当匹配,因此咱们将上一步的输出reshape成 [batch_size*max_time_step,num_hidden](num_hidden为本身定义的lstm的unit个数)记为 A,而后将其作一个线性变换,因而 A*w+b获得形如 [batch_size*max_time_step,num_classes]而后在reshape回来获得 [batch_size,max_time_step,num_classes]最后因为ctc_loss的要求,咱们再作一次转置,获得 [max_time_step,batch_size,num_classes]形状的数据做为input

labels: 标签序列
因为OCR的结果是不定长的,因此label其实是一个稀疏矩阵SparseTensor
其中:api

  • indices:二维int64的矩阵,表明非0的坐标点
  • values:二维tensor,表明indice位置的数据值
  • dense_shape:一维,表明稀疏矩阵的大小
    好比有两幅图,分别是123,和4567那么
    indecs = [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]]
    values = [1,2,3,4,5,6,7]
    dense_shape = [2,4]
    表明dense tensor:
    1
    2
    [[1,2,3,0]
    [4,5,6,7]]

seq_len: 在input一节中已经讲过,一维数据,[time_step,…,time_step]长度为batch_size,值为time_step网络

相关文章
相关标签/搜索