For the past few days I have been using PyTorch to reproduce the CTPN paper from the text detection field. This article walks through the reproduction step by step: data processing, training label generation, network construction, loss function design, and the main training loop. The theory behind the CTPN algorithm can be found here.
For training we chose two datasets, Tianchi ICPR2018 and MSRA_TD500. The Tianchi ICPR dataset consists of web images, mostly product-introduction pictures uploaded to Taobao by merchants. Its labels follow the ICDAR2015 format: a text box is represented by its four corners (top-left, top-right, bottom-right, bottom-left), eight values in total, written as [x1 y1 x2 y2 x3 y3 x4 y4].
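Concretely, after our conversion below each line of a ground-truth file is eight comma-separated integers, one text box per line (made-up coordinates for illustration):

377,117,463,117,465,130,378,130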
The Tianchi ICPR2018 dataset looks like the following: fonts, shapes, formats, and colors vary widely, text is often embedded in objects, and detection is difficult:
MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Its images are mostly street scenes with fairly complex backgrounds, but the text locations are prominent and easy to spot. MSRA_TD500 uses a different label format, whose last parameter is the rotation angle of the rectangle.
So our first step is to unify the label formats of the two datasets. My approach is to convert the MSRA labels into the ICDAR format, which makes the later model training easier. Since MSRA_TD500 labels are given as [index difficulty_label x y w h angle], we need to use the rotation angle to compute the four corner coordinates of the rotated horizontal box. The implementation is as follows:
""" This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format. MSRA_TD500 format: [index difficulty_label x y w h angle] ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y] """ import math import cv2 import os # 求旋转后矩形的4个坐标 def get_box_img(x, y, w, h, angle): # 矩形框中点(x0,y0) x0 = x + w/2 y0 = y + h/2 l = math.sqrt(pow(w/2, 2) + pow(h/2, 2)) # 即对角线的一半 # angle小于0,逆时针转 if angle < 0: a1 = -angle + math.atan(h / float(w)) # 旋转角度-对角线与底线所成的角度 a2 = -angle - math.atan(h / float(w)) # 旋转角度+对角线与底线所成的角度 pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2)) pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1)) pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2)) # x0+左下点旋转后在水平线上的投影, y0-左下点在垂直线上的投影,显然逆时针转时,左下点上一和左移了。 pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1)) else: a1 = angle + math.atan(h / float(w)) a2 = angle - math.atan(h / float(w)) pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1)) pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2)) pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1)) pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2)) return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]] def read_file(path): result = [] for line in open(path): info = [] data = line.split(' ') info.append(int(data[2])) info.append(int(data[3])) info.append(int(data[4])) info.append(int(data[5])) info.append(float(data[6])) info.append(data[0]) result.append(info) return result if __name__ == '__main__': file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/' save_img_path = '../dataset/OCR_dataset/ctpn/test_im/' save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/' file_list = os.listdir(file_path) for f in file_list: if '.gt' in f: continue name = f[0:8] txt_path = file_path + name + '.gt' im_path = file_path + f im = cv2.imread(im_path) coordinate = read_file(txt_path) # 仿照ICDAR格式,图片名字写作img_xx.jpg,对应的标签文件写作gt_img_xx.txt cv2.imwrite(save_img_path + name.lower() + '.jpg', im) save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w') for i in coordinate: box = get_box_img(i[0], i[1], i[2], i[3], i[4]) box = [int(box[i]) for i in range(len(box))] box = [str(box[i]) for i in range(len(box))] save_gt.write(','.join(box)) save_gt.write('\n')
After this format conversion, our two datasets are in order. Of course, we still need to split the whole dataset into a training set and a test set. My file organization habit is: train_im and test_im hold the training and test images, while train_gt and test_gt hold the training and test labels.
Since the core idea of CTPN is also based on the region proposal mechanism of Faster R-CNN, the original labels need to be converted into anchor labels. Generating the training labels is the hardest code to write: decomposing a complete text-box label into a series of fine-scale text-box labels is genuinely tricky, and the anchor label generation also differs slightly from how Faster R-CNN does it. Here is my implementation approach:
Step one: convert each image's bbox labels into per-anchor labels. To do this, we first divide the image into anchors of width 16.
def generate_gt_anchor(img, box, anchor_width=16):
    """
    calculate ground truth fine-scale box
    :param img: input image
    :param box: ground truth box (4 point)
    :param anchor_width:
    :return: list of tuples (position, cy, h)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # work out how many 16-pixel-wide anchors one bbox decomposes into,
    # and find the ids of the leftmost and rightmost anchors
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, rounded down
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, rounded up
    # handle an extreme case: the rightmost anchor may exceed the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1
    # combine the left-side and right-side x coordinates of each text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1)
                     for i in range(left_anchor_num, right_anchor_num)]
    # compute the true extent of each gt anchor, i.e. its top and bottom boundaries
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
    # finally store each anchor's horizontal position (index), center y coordinate and height
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of the anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of the anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center y of the anchor box
        result.append((position, cy, h))
    return result
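A quick hand-worked example (my own numbers): for a box whose x extent runs from 40 to 100 on a sufficiently wide image, left_anchor_num = floor(40 / 16) = 2 and right_anchor_num = ceil(100 / 16) = 7, so the box decomposes into five anchors with horizontal ids 2 through 6:

# x ranges covered by the anchors of a box spanning x in [40, 100]
position_pair = [(32, 47), (48, 63), (64, 79), (80, 95), (96, 111)]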
The method for computing an anchor's top and bottom boundaries:
# compute the gt anchor box's top and bottom coordinates
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example: [(0, 15), (16, 31), ...]
    :param box: gt box (4 point)
    :return: top and bottom coordinates for y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]
    # build an image mask: set channel 0 to all black
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            img[i, j, 0] = 0
    top_flag = False
    bottom_flag = False
    # draw the text box from its 4 points; in channel 0 the box appears white
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))
    for k in range(len(position_pair)):
        # traverse the gt anchors from left to right; for each anchor scan pixels from top
        # to bottom and stop at the first white pixel (255): its y is the anchor's top boundary
        for y in range(0, height - 1):
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break
        # likewise, scan pixels from bottom to top and stop at the first white pixel (255):
        # its y is the anchor's bottom boundary
        for y in range(height - 1, -1, -1):
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break
        top_flag = False
        bottom_flag = False
    return y_top, y_bottom
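The nested Python pixel loops above are slow. A vectorized sketch along the same lines (my own alternative, assuming draw_box_4pt just draws the closed box outline, and that every anchor strip intersects the drawn box as the original code also assumes) could rasterize the box once onto a single-channel mask and query it with NumPy:

import numpy as np
import cv2

def cal_y_top_and_bottom_fast(raw_img, position_pair, box):
    # rasterize the text box outline once onto a single-channel mask
    mask = np.zeros(raw_img.shape[:2], dtype=np.uint8)
    pts = np.array(box, dtype=np.int32).reshape((-1, 1, 2))
    cv2.polylines(mask, [pts], isClosed=True, color=255)
    y_top, y_bottom = [], []
    for left, right in position_pair:
        rows, _ = np.nonzero(mask[:, left:right + 1])  # rows containing box pixels
        y_top.append(int(rows.min()))
        y_bottom.append(int(rows.max()))
    return y_top, y_bottom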
With the label processing above, we have converted the original standard text-box labels into fine-scale anchor labels. Here is what the labels look like after conversion:
Visualized, the anchor labels look decent, but I should point out that this generation method is not very precise. For example, if an edge pixel of a text box happens to fall just inside a new anchor, we have to assign that pixel a full 16-pixel-wide anchor, which clearly makes the text-box label inaccurate and introduces up to 15 pixels of error. This is worth thinking about, but we will leave it for now and continue.
Of course, the conversion also surfaced some odd cases, such as labels that extend beyond the image boundary, as in the image below. These must be handled specially, for example by clamping the label's x coordinates to the image width.
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # clamp the left side at 0
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # clamp the right side at the image width
Since CTPN uses a CNN + bidirectional LSTM network architecture, we implement the CTPN architecture in parts.
CNN部分CTPN采起了VGG16进行底层特征提取。
class VGG_16(nn.Module):
    """
    VGG-16 without pooling layer before fc layer
    """
    def __init__(self):
        super(VGG_16, self).__init__()
        self.convolution1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.convolution1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pooling1 = nn.MaxPool2d(2, stride=2)
        self.convolution2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.convolution2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pooling2 = nn.MaxPool2d(2, stride=2)
        self.convolution3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.convolution3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.convolution3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pooling3 = nn.MaxPool2d(2, stride=2)
        self.convolution4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.convolution4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pooling4 = nn.MaxPool2d(2, stride=2)
        self.convolution5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_3 = nn.Conv2d(512, 512, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.convolution1_1(x), inplace=True)
        x = F.relu(self.convolution1_2(x), inplace=True)
        x = self.pooling1(x)
        x = F.relu(self.convolution2_1(x), inplace=True)
        x = F.relu(self.convolution2_2(x), inplace=True)
        x = self.pooling2(x)
        x = F.relu(self.convolution3_1(x), inplace=True)
        x = F.relu(self.convolution3_2(x), inplace=True)
        x = F.relu(self.convolution3_3(x), inplace=True)
        x = self.pooling3(x)
        x = F.relu(self.convolution4_1(x), inplace=True)
        x = F.relu(self.convolution4_2(x), inplace=True)
        x = F.relu(self.convolution4_3(x), inplace=True)
        x = self.pooling4(x)
        x = F.relu(self.convolution5_1(x), inplace=True)
        x = F.relu(self.convolution5_2(x), inplace=True)
        x = F.relu(self.convolution5_3(x), inplace=True)
        return x
Next we implement the bidirectional LSTM, which strengthens the learning of sequential context.
class BLSTM(nn.Module):
    def __init__(self, channel, hidden_unit, bidirectional=True):
        """
        :param channel: lstm input channel num
        :param hidden_unit: lstm hidden unit
        :param bidirectional:
        """
        super(BLSTM, self).__init__()
        self.lstm = nn.LSTM(channel, hidden_unit, bidirectional=bidirectional)

    def forward(self, x):
        """
        WARNING: The batch size of x must be 1.
        """
        x = x.transpose(1, 3)  # (N, C, H, W) -> (N, W, H, C)
        recurrent, _ = self.lstm(x[0])  # treat W as the sequence length and H as the batch
        recurrent = recurrent[np.newaxis, :, :, :]  # restore the batch dimension
        recurrent = recurrent.transpose(1, 3)  # back to (N, 2 * hidden_unit, H, W)
        return recurrent
Here we implement one more intermediate layer, connecting the CNN and the LSTM. It converts the feature map output by VGG's last convolutional layer into the vector form needed for the subsequent LSTM training.
class Im2col(nn.Module):
    def __init__(self, kernel_size, stride, padding):
        super(Im2col, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        height = x.shape[2]
        x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        x = x.reshape((x.shape[0], x.shape[1], height, -1))
        return x
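To see what Im2col does to the tensor shapes, here is a small trace (my own example, using a fake conv5 feature map):

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 14, 25)                  # a fake VGG conv5 feature map
unfolded = F.unfold(x, (3, 3), padding=(1, 1), stride=(1, 1))
print(unfolded.shape)                            # torch.Size([1, 4608, 350]): 512 * 3 * 3 channels, 14 * 25 positions
print(unfolded.reshape(1, 4608, 14, -1).shape)   # torch.Size([1, 4608, 14, 25]): each position now holds its 3x3 neighborhood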
Finally we assemble the three parts above into the full CTPN network: VGG16 at the bottom for feature extraction -> LSTM for sequential information learning -> output per anchor: score, h, y, and side_refinement.
class CTPN(nn.Module):
    def __init__(self):
        super(CTPN, self).__init__()
        self.cnn = nn.Sequential()
        self.cnn.add_module('VGG_16', VGG_16())
        self.rnn = nn.Sequential()
        self.rnn.add_module('im2col', Net.Im2col((3, 3), (1, 1), (1, 1)))
        self.rnn.add_module('blstm', BLSTM(3 * 3 * 512, 128))
        self.FC = nn.Conv2d(256, 512, 1)
        # 2k outputs (k = 10): 10 is the number of anchor sizes, and the 2 parameters are the anchor's h and dy
        self.vertical_coordinate = nn.Conv2d(512, 2 * 10, 1)
        # 2k scores (k = 10): 2 for text / non-text, 10 for the number of anchor sizes
        self.score = nn.Conv2d(512, 2 * 10, 1)
        # k outputs (k = 10): the horizontal offset of each anchor, used to refine the
        # horizontal edges of the text box; 10 is the number of anchor sizes
        self.side_refinement = nn.Conv2d(512, 10, 1)

    def forward(self, x, val=False):
        x = self.cnn(x)
        x = self.rnn(x)
        x = self.FC(x)
        x = F.relu(x, inplace=True)
        vertical_pred = self.vertical_coordinate(x)
        score = self.score(x)
        if val:
            score = score.reshape((score.shape[0], 10, 2, score.shape[2], score.shape[3]))
            score = score.squeeze(0)
            score = score.transpose(1, 2)
            score = score.transpose(2, 3)
            score = score.reshape((-1, 2))
            # score = F.softmax(score, dim=1)
            score = score.reshape((10, vertical_pred.shape[2], -1, 2))
            vertical_pred = vertical_pred.reshape((vertical_pred.shape[0], 10, 2,
                                                   vertical_pred.shape[2], vertical_pred.shape[3]))
        side_refinement = self.side_refinement(x)
        return vertical_pred, score, side_refinement
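Tracing the shapes through the forward pass helps keep the pieces straight (my own numbers, assuming a 224 x 400 input):

# input:            (1, 3, 224, 400)
# after VGG_16:     (1, 512, 14, 25)   four 2x2 poolings halve H and W four times
# after Im2col:     (1, 4608, 14, 25)  4608 = 3 * 3 * 512
# after BLSTM:      (1, 256, 14, 25)   256 = 2 * 128 (bidirectional)
# after FC (1x1):   (1, 512, 14, 25)
# vertical_pred:    (1, 20, 14, 25)    2 values (cy, h) for each of 10 anchor sizes
# score:            (1, 20, 14, 25)    2 scores (text / non-text) for each of 10 anchor sizes
# side_refinement:  (1, 10, 14, 25)    1 offset for each of 10 anchor sizes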
The CTPN loss has three parts: a classification loss, a vertical coordinate regression loss, and a side-refinement regression loss.
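In the paper the three are combined into one objective (Eq. 1 of the CTPN paper), where s are the text/non-text scores, v the vertical coordinates, o the side-refinement offsets, starred symbols the ground truth, and N_s, N_v, N_o the anchor counts for each term:

\[
L(s_i, v_j, o_k) = \frac{1}{N_s}\sum_i L_s^{cls}(s_i, s_i^*) + \frac{\lambda_1}{N_v}\sum_j L_v^{reg}(v_j, v_j^*) + \frac{\lambda_2}{N_o}\sum_k L_o^{reg}(o_k, o_k^*)
\]

The code below uses Ns = 128, lambda1 = lambda2 = 1.0, cross-entropy for the classification term, and smooth L1 for both regression terms.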
First we define some fixed parameters:
class CTPN_Loss(nn.Module):
    def __init__(self, using_cuda=False):
        super(CTPN_Loss, self).__init__()
        self.Ns = 128
        self.ratio = 0.5
        self.lambda1 = 1.0
        self.lambda2 = 1.0
        self.Ls_cls = nn.CrossEntropyLoss()
        self.Lv_reg = nn.SmoothL1Loss()
        self.Lo_reg = nn.SmoothL1Loss()
        self.using_cuda = using_cuda
First, the classification loss:
cls_loss = 0.0
if self.using_cuda:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]).cuda())
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]).cuda())
else:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]))
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]))
cls_loss = cls_loss / self.Ns
Next is the vertical coordinate regression loss, which reflects the deviation in y and h.
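The regression targets follow the paper's relative parameterization (Eq. 2), where c_y^a and h^a are the center y and height of the anchor, so the network predicts offsets rather than absolute coordinates:

\[
v_c = (c_y - c_y^a) / h^a, \qquad v_h = \log(h / h^a)
\]

with the ground-truth pair v_c^*, v_h^* computed the same way from the gt box.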
# calculate vertical coordinate regression loss
v_reg_loss = 0.0
Nv = len(vertical_reg)
if self.using_cuda:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0).cuda())
else:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0))
v_reg_loss = v_reg_loss / float(Nv)
Finally we compute the side refinement regression loss, used to refine the precision of the horizontal edges.
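Its target is likewise relative (Eq. 3 of the paper), where x_side is the x coordinate of the nearest ground-truth horizontal side and c_x^a, w^a = 16 are the anchor's center x and width:

\[
o = (x_{side} - c_x^a) / w^a
\]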
# calculate side refinement regression loss
o_reg_loss = 0.0
No = len(side_refinement_reg)
if self.using_cuda:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0).cuda())
else:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0))
o_reg_loss = o_reg_loss / float(No)
Finally there is the total loss, which combines the three terms used throughout training:
loss = cls_loss + v_reg_loss * self.lambda1 + o_reg_loss * self.lambda2
Training: we choose SGD as the optimizer and set two learning rates: a larger lr for the first N epochs and a smaller lr for the later epochs, for better convergence. During training we track four losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss, and total_loss (the sum of the first three).
net = Net.CTPN()  # build the network
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
# for name, value in net.named_parameters():
#     print('name: {0}, grad: {1}'.format(name, value.requires_grad))
net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
    net.cuda()
net.train()
print(net)

criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # build the loss

train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # fetch training and validation data
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))

train_loss_list = []
test_loss_list = []

# start the training iterations
for i in range(epoch):
    if i >= change_epoch:
        lr = lr_behind
    else:
        lr = lr_front
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # optimizer = optim.Adam(net.parameters(), lr=lr)
    iteration = 1
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()

    random.shuffle(train_im_list)  # shuffle the training set
    # print(random_im_list)
    for im in train_im_list:
        root, file_name = os.path.split(im)
        root, _ = os.path.split(root)
        name, _ = os.path.splitext(file_name)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(root, "train_gt", gt_name)

        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the matching label
        # print("processing image %s" % os.path.join(img_root1, im))
        img = cv2.imread(im)
        if img is None:
            iteration += 1
            continue
        img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # rescale the image and its label
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)  # forward pass, get the predictions
        del tensor_img

        # transform bbox gt to anchor gt for training
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        visual_img = copy.deepcopy(img)  # this copy is used to visualize the labels

        try:
            # loop over all bboxes in one image
            for box in gt_txt:
                # generate anchors from one bbox
                gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # get the image's anchor labels
                positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # map the predictions onto anchor-level training data
                positive += positive1
                negative += negative1
                vertical_reg += vertical_reg1
                side_refinement_reg += side_refinement_reg1
        except:
            print("warning: img %s raise error!" % im)
            iteration += 1
            continue

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            iteration += 1
            continue

        cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
        optimizer.zero_grad()
        # compute the loss
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # backpropagate
        loss.backward()
        optimizer.step()
        iteration += 1
        # save gpu memory by transferring loss to float
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)

        if iteration % display_iter == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                  format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                         total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))

            logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
            logger.info('loss: {0}'.format(total_loss / display_iter))
            logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
            logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
            logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))

            train_loss_list.append(total_loss)

            total_loss = 0
            total_cls_loss = 0
            total_v_reg_loss = 0
            total_o_reg_loss = 0
            start_time = time.time()

        # periodically evaluate the model
        if iteration % val_iter == 0:
            net.eval()
            logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
            val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
            logger.info('End evaluate.')
            net.train()
            start_time = time.time()
            test_loss_list.append(val_loss)

        # periodically save the model
        if iteration % save_iter == 0:
            print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
            torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))

    print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
    torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))

# plot the loss curves
draw_loss_plot(train_loss_list, test_loss_list)
Rescaling the images follows a fixed rule: the shortest side of the image must equal 600. We compute the scaling factor as scale = float(shortest_side)/float(min(height, width)) and resize the original image by it. The labels must be scaled by the same factor.
def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side) / float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # correct rounding drift so the shortest side is exactly 600;
    # note that cv2.resize takes dsize as (width, height)
    if img.shape[0] < img.shape[1] and img.shape[0] != 600:
        img = cv2.resize(img, (img.shape[1], 600))
    elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
        img = cv2.resize(img, (600, img.shape[0]))
    elif img.shape[0] != 600:
        img = cv2.resize(img, (600, 600))
    h_scale = float(img.shape[0]) / float(height)
    w_scale = float(img.shape[1]) / float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))  # x coordinate
            else:
                scale_box.append(int(int(box[i]) * h_scale))  # y coordinate
        scale_gt.append(scale_box)
    return img, scale_gt
Validation-set evaluation:
def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue
        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)
        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1
        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += float(loss)  # cast to float so the graph can be freed
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)
    end_time = time.time()
    total_time = end_time - start_time
    print('#################### Start evaluate ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))
    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('##################### Evaluate end #####################')
    print('\n')
    return total_loss / float(batch_num)  # the training loop records this as val_loss
Training progress:
Testing: feed in an image and produce the final detection result.
def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # rescale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # run the network
    result = []
    # collect the anchors whose text score exceeds the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    # filter with NMS
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)

    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])

    # decide which anchors belong to the same group
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)

    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])

    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_' + basename, im)
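The anchor_height table used above is not shown in the post. The paper says the 10 anchor heights run from 11 to 273 pixels, each roughly 1/0.7 times the previous one, so a plausible (assumed, rounded) definition would be:

# assumed anchor heights following the paper's "11 to 273, divided by 0.7 each time" rule
anchor_height = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]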
The inference code calls get_successions, which gathers all the anchors belonging to one predicted text line. In other words, we end up with many anchors predicted to contain characters, but how do we know which anchors can be joined into one text line? We need an anchor-merging algorithm, and this is the most difficult step in implementing CTPN.
The CTPN paper describes text line construction as follows: simply connect consecutive text proposals whose text/non-text score is > 0.7. Concretely, in the implementation below two proposals are grouped when the horizontal distance between them is under 50 pixels and their vertical overlap exceeds a threshold. The construction procedure is shown next.
The theory looks simple, but implementing it yourself is another matter entirely; as the saying goes, "what you learn on paper is shallow, true understanding only comes from doing". get_successions takes two parameters: v, which carries each predicted anchor's h and y information, and anchors, which carries each anchor's four corner coordinates.
def get_successions(v, anchors=[]):
    texts = []
    for i, anchor in enumerate(anchors):
        neighbours = []  # record the anchors of one group
        neighbours.append(i)
        center_x1 = (anchor[2] + anchor[0]) / 2
        h1 = get_anchor_h(anchor, v)  # the height of this anchor
        # find i's neighbours: traverse the remaining anchors
        for j in range(i + 1, len(anchors)):
            center_x2 = (anchors[j][2] + anchors[j][0]) / 2  # center x coordinate
            h2 = get_anchor_h(anchors[j], v)
            # two anchors are neighbours if they are less than 50 pixels apart and
            # their vertical overlap (vertical IoU) is above a threshold
            if abs(center_x1 - center_x2) < NEIGHBOURS_MIN_DIST and \
                    meet_v_iou(max(anchor[1], anchors[j][1]), min(anchor[3], anchors[j][3]), h1, h2):
                neighbours.append(j)
        if len(neighbours) != 0:
            texts.append(neighbours)

    # at this point every anchor's neighbours have been collected into a set;
    # now repeatedly merge any groups that share an anchor
    need_merge = True
    while need_merge:
        need_merge = False
        # ok, we combine again.
        for i, line in enumerate(texts):
            if len(line) == 0:
                continue
            for index in line:
                for j in range(i + 1, len(texts)):
                    if index in texts[j]:
                        texts[i] += texts[j]
                        texts[i] = list(set(texts[i]))
                        texts[j] = []
                        need_merge = True

    result = []
    # print(texts)
    for text in texts:
        if len(text) < 2:
            continue
        local = []
        for j in text:
            local.append(anchors[j])
        result.append(local)
    return result
Once we have a text box's anchor groups, the next step is to join the anchors within each group into one text box; the get_text_lines function does exactly this.
def get_text_lines(text_proposals, im_size, scores=0):
    """
    text_proposals: boxes
    """
    text_lines = np.zeros((len(text_proposals), 8), np.float32)

    for index, tp_indices in enumerate(text_proposals):
        text_line_boxes = np.array(tp_indices)  # all the fine-scale boxes of one text line
        # print(text_line_boxes)
        # print(type(text_line_boxes))
        # print(text_line_boxes.shape)
        X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center x of each fine-scale box
        Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2  # center y of each fine-scale box
        # print(X)
        # print(Y)

        z1 = np.polyfit(X, Y, 1)  # least-squares fit of a straight line through the centers

        x0 = np.min(text_line_boxes[:, 0])  # smallest x coordinate of the text line
        x1 = np.max(text_line_boxes[:, 2])  # largest x coordinate of the text line

        offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half a fine-scale box width

        # fit a line through the top-left corners of all fine-scale boxes, then compute
        # the y values at the extreme left and right x coordinates of the text line
        lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
        # fit a line through the bottom-left corners of all fine-scale boxes, same procedure
        lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

        # score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean score of the fine-scale boxes as the line score

        text_lines[index, 0] = x0
        text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the text line's top segment
        text_lines[index, 2] = x1
        text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the text line's bottom segment
        text_lines[index, 4] = scores  # text line score
        text_lines[index, 5] = z1[0]  # k and b of the line fitted through the centers
        text_lines[index, 6] = z1[1]
        height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean fine-scale box height
        text_lines[index, 7] = height + 2.5

    text_recs = np.zeros((len(text_lines), 9), np.float32)
    index = 0
    for line in text_lines:
        b1 = line[6] - line[7] / 2  # from the height and the fitted center line, get the b of the top and bottom lines
        b2 = line[6] + line[7] / 2
        x1 = line[0]
        y1 = line[5] * line[0] + b1  # top left
        x2 = line[2]
        y2 = line[5] * line[2] + b1  # top right
        x3 = line[0]
        y3 = line[5] * line[0] + b2  # bottom left
        x4 = line[2]
        y4 = line[5] * line[2] + b2  # bottom right
        disX = x2 - x1
        disY = y2 - y1
        width = np.sqrt(disX * disX + disY * disY)  # text line width

        fTmp0 = y3 - y1  # text line height
        fTmp1 = fTmp0 * disY / width
        x = np.fabs(fTmp1 * disX / width)  # compensation
        y = np.fabs(fTmp1 * disY / width)
        if line[5] < 0:
            x1 -= x
            y1 += y
            x4 += x
            y4 -= y
        else:
            x2 += x
            y2 += y
            x3 -= x
            y3 -= y

        # clockwise order
        text_recs[index, 0] = x1
        text_recs[index, 1] = y1
        text_recs[index, 2] = x2
        text_recs[index, 3] = y2
        text_recs[index, 4] = x4
        text_recs[index, 5] = y4
        text_recs[index, 6] = x3
        text_recs[index, 7] = y3
        text_recs[index, 8] = line[4]
        index = index + 1

    text_recs = clip_boxes(text_recs, im_size)

    return text_recs
First, let's look at the trained model's text detection results. For easier inspection I draw both the anchors and the final merged text boxes:
Here are some of the better detection results:
Some takeaways and thoughts from the implementation process:
The complete CTPN implementation is available on my GitHub.