"Learning Efficient Object Detection Models with Knowledge Distillation"这篇文章通过知识蒸馏(Knowledge Distillation)与Hint指导学习(Hint Learning),提升了主干精简的多分类目标检测网络的推理精度(文章以Faster R-CNN为例),例如Faster R-CNN-AlexNet、Faster R-CNN-VGGM等,具体框架如下图所示:
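教师网络的暗知识(dark knowledge)提取分为三部分:中间层Feature Maps的Hint;RPN/RCN中分类层的暗知识;以及RPN/RCN中回归层的暗知识。总体损失为 $L = L_{RPN} + L_{RCN} + \gamma L_{Hint}$,其中 $L_{RPN}$ 与 $L_{RCN}$ 均由分类损失 $L_{cls}$ 与加权(系数 $\lambda$)的回归损失 $L_{reg}$ 组成。具体如下: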
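具体指导学生网络学习时,RPN与RCN的分类损失由两部分构成:分类层softmax输出与hard target(真实标签)之间的交叉熵loss,以及分类层softmax输出与soft target(教师网络经温度 $T$ 软化后的softmax输出)之间的交叉熵loss:$L_{cls} = \mu L_{hard}(P_s, y) + (1-\mu) L_{soft}(P_s, P_t)$,其中 $P_t = \mathrm{softmax}(Z_t / T)$,$P_s = \mathrm{softmax}(Z_s / T)$。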
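由于检测器需要鉴别的不同类别之间存在样本不均衡(imbalance),因此在 $L_{soft}$ 中需要对不同类别的交叉熵分配不同的权重,其中背景类的权重为1.5(较大的比例),其余类别的权重均为1.0:$L_{soft}(P_s, P_t) = -\sum_{c} w_c \, P_t^{(c)} \log P_s^{(c)}$,其中 $w_0 = 1.5$(背景类),$w_c = 1.0$($c > 0$)。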
RPN与RCN的回归损失由常规的smooth L1 loss与文章定义的teacher bounded regression loss构成:$L_{reg} = L_{sL1}(R_s, y_{reg}) + \nu L_b(R_s, R_t, y_{reg})$。
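其中 $L_{sL1}$ 表示常规的smooth L1 loss,$L_b$ 表示文章定义的teacher bounded regression loss:$L_b(R_s, R_t, y_{reg}) = \begin{cases} \|R_s - y_{reg}\|_2^2, & \|R_s - y_{reg}\|_2^2 + m > \|R_t - y_{reg}\|_2^2 \\ 0, & \text{otherwise} \end{cases}$。即:当学生网络回归结果与ground truth的L2距离加上裕量 $m$ 后仍大于教师网络回归结果与ground truth的L2距离时,$L_b$ 取学生网络回归结果与ground truth之间的L2距离;否则(学生已足够接近或优于教师)$L_b$ 置0。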
Hint learning需要计算教师网络与学生网络中间层输出的Feature Maps之间的L2 loss,并且需要在学生网络中添加可学习的适配层(adaptation layer),以确保guided layer输出的Feature Maps与教师网络输出的Hint维度一致:$L_{Hint}(V, Z) = \|V - Z\|_2^2$,其中 $Z$ 为教师网络的Hint,$V$ 为学生网络guided layer经适配层后的输出。
通过知识蒸馏与Hint指导学习,提升了精简网络的泛化能力,并有助于加快收敛,最终取得了良好的实验结果,具体见文章实验部分。
以SSD为例,KD loss与teacher bounded L2 loss的设计如下(代码中soft loss以KL散度实现,它与上述加权交叉熵形式只相差一个仅依赖教师输出的常数项,因此对学生网络的梯度相同):
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..box_utils import match, log_sum_exp

# Small constant added to both distributions in KL_div so that log(0) and
# division by zero cannot occur.
eps = 1e-5


def KL_div(p, q, pos_w, neg_w):
    """Class-weighted KL divergence KL(p || q), summed over all elements.

    Args:
        p: target (teacher) probabilities, 2-D tensor [num_samples, num_classes].
        q: predicted (student) probabilities, same shape as ``p``.
        pos_w: weight applied to every non-background column (columns 1:).
        neg_w: weight applied to the background column (column 0).

    Returns:
        Scalar tensor ``sum_{i,c} w_c * p_ic * log(p_ic / q_ic)`` computed on
        ``p + eps`` and ``q + eps``.
    """
    p = p + eps
    q = q + eps
    kl = p * torch.log(p / q)
    # Column 0 is the background class; it is weighted separately from the
    # object classes to counter the background/foreground imbalance.
    kl[:, 0] *= neg_w
    kl[:, 1:] *= pos_w
    return torch.sum(kl)


class MultiBoxLoss(nn.Module):
    """SSD MultiBox loss extended with knowledge-distillation terms.

    Besides the usual SSD localization (Smooth L1) and confidence
    (cross-entropy with hard negative mining) losses, ``forward`` also returns:

    * a soft-target loss: KL divergence between the temperature-softened
      teacher and student class distributions, optionally class-weighted;
    * a teacher-bounded L2 regression loss: the student's squared L2 box error
      on positive priors, counted only where it (plus margin ``reg_m``)
      exceeds the teacher's squared L2 error.

    Args:
        num_classes: number of classes including background (e.g. 21).
        overlap_thresh: IoU threshold passed to ``match`` (e.g. 0.5).
        prior_for_matching: stored as ``use_prior_for_matching``; not read by
            ``forward``.
        bkg_label: background label (e.g. 0); stored, ``forward`` assumes
            labels > 0 are positives.
        neg_mining: stored as ``do_neg_mining``; not read by ``forward``.
        neg_pos: ratio of mined negatives to positives (e.g. 3).
        neg_overlap: stored; not read by ``forward``.
        encode_target: stored; not read by ``forward``.
        cfg: config mapping; ``cfg['variance']`` is passed to ``match``.
        use_gpu: move matched targets to CUDA when True.
        neg_w: soft-loss weight for the background class.
        pos_w: soft-loss weight for the object classes.
        Temp: softmax temperature for the soft-target loss.
        reg_m: margin ``m`` of the teacher-bounded regression loss.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 cfg, use_gpu=True, neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes  # e.g. 21
        self.threshold = overlap_thresh  # e.g. 0.5
        self.background_label = bkg_label  # e.g. 0
        self.encode_target = encode_target  # e.g. False
        self.use_prior_for_matching = prior_for_matching  # e.g. True
        self.do_neg_mining = neg_mining  # e.g. True
        self.negpos_ratio = neg_pos  # e.g. 3
        self.neg_overlap = neg_overlap  # e.g. 0.5
        self.variance = cfg['variance']
        # Soft-target / distillation hyper-parameters.
        self.neg_w = neg_w
        self.pos_w = pos_w
        self.Temp = Temp
        self.reg_m = reg_m

    def forward(self, predictions, pred_t, targets):
        """Compute the student's detection and distillation losses.

        Args:
            predictions (tuple): ``(loc, conf, priors)`` from the student SSD:
                loc   [batch_size, num_priors, 4],
                conf  [batch_size, num_priors, num_classes],
                priors [num_priors, 4].
            pred_t (sequence): teacher predictions; ``pred_t[0]`` must be a
                ``(loc, conf)`` pair with the same shapes as the student's
                loc / conf (they are indexed with the student's masks).
            targets (sequence): per-image ground truth, each
                [num_objs, 5] with the label in the last column.

        Returns:
            tuple ``(loss_l, loss_c, soft_loss, l2_loss)``: localization,
            hard-target confidence, soft-target (KD) and teacher-bounded L2
            losses, each summed and divided by the number of positive priors
            in the batch (at least 1).
        """
        loc_data, conf_data, priors = predictions
        num = loc_data.size(0)
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)

        # Teacher outputs are fixed targets: detach them so that neither the
        # KD loss nor the bounded regression loss back-propagates into the
        # teacher network.
        loc_teach1, conf_teach1 = pred_t[0]
        loc_teach1 = loc_teach1.detach()
        conf_teach1 = conf_teach1.detach()

        # Match priors (default boxes) with ground-truth boxes; ``match``
        # fills loc_t / conf_t in place for image ``idx``.
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels,
                  loc_t, conf_t, idx)
        with torch.no_grad():
            if self.use_gpu:
                loc_t = loc_t.cuda(non_blocking=True)
                conf_t = conf_t.cuda(non_blocking=True)

        pos = conf_t > 0  # [batch, num_priors]; True for matched priors
        num_pos = pos.long().sum(1, keepdim=True)  # [batch, 1]

        # ---- Localization loss (Smooth L1) over positive priors ----
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # ---- Teacher-bounded L2 regression loss ----
        loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)
        l2_dis_s = (loc_p - loc_t).pow(2).sum(1)
        l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)
        # Penalize the student only where its error plus the margin still
        # exceeds the teacher's error.
        bounded = (l2_dis_s + self.reg_m) > l2_dis_t
        l2_loss = l2_dis_s[bounded].sum()

        # ---- Hard negative mining ----
        # Only the ranking (indices) is used, so no gradient is needed here.
        batch_conf = conf_data.detach().view(-1, self.num_classes)
        mining_loss = (log_sum_exp(batch_conf.float())
                       - batch_conf.gather(1, conf_t.view(-1, 1)).float())
        mining_loss[pos.view(-1, 1)] = 0  # keep positives out of the ranking
        mining_loss = mining_loss.view(num, -1)
        _, loss_idx = mining_loss.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_neg = torch.clamp(self.negpos_ratio * num_pos,
                              max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # ---- Confidence loss (hard targets) on positives + mined negatives ----
        keep = pos | neg
        keep_idx = keep.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[keep_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[keep]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # ---- Soft-target (knowledge distillation) loss ----
        conf_p_teach = conf_teach1[keep_idx].view(-1, self.num_classes)
        pt = F.softmax(conf_p_teach / self.Temp, dim=1)
        if self.neg_w != 1. or self.pos_w != 1.:
            # Class-weighted KL; used whenever either weight is non-unit so
            # that pos_w and down-weighting (neg_w < 1) are not ignored.
            ps = F.softmax(conf_p / self.Temp, dim=1)
            soft_loss = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
        else:
            ps = F.log_softmax(conf_p / self.Temp, dim=1)
            soft_loss = F.kl_div(ps, pt, reduction='sum') * (self.Temp ** 2)

        # L(x,c,l,g) = (Lconf(x, c) + alpha * Lloc(x, l, g)) / N, where N is
        # the number of positive priors. Clamp N to 1 so a batch with no
        # positives yields finite losses instead of NaN.
        N = num_pos.sum().float().clamp(min=1.)
        loss_l = loss_l.float() / N
        loss_c = loss_c.float() / N
        l2_loss = l2_loss / N
        soft_loss = soft_loss / N
        return loss_l, loss_c, soft_loss, l2_loss
Paper地址:https://papers.nips.cc/paper/6676-learning-efficient-object-detection-models-with-knowledge-distillation.pdf
PyTorch版SSD:https://github.com/amdegroot/ssd.pytorch