Spatially Adaptive Residual Networks for Efficient Image and Video Deblurring

paper python

本文是印度马德拉斯理工学院的研究员提出的一种基于空间自适应残差网络的图像/视频去模糊方法。网络

严重模糊图像复原要求网络具备极大感觉野,现有网络每每采用加深网络层数、加大卷积核尺寸或者多尺度方式提高感觉野,然而这些方法会早知模型大小的提高以及推理耗时提高。做者提出一种组合形变卷积与自注意力机制的去模糊网络,进一步,集成时序递归模块能够将其扩展到视频去模糊。该网络能够模拟空间可变模糊移除而无需多尺度与大卷积核。最后做者经过实验定性与定量进行分析:在速度、精度以及模型大小方面均取得了SOTA性能。架构

Abstract

​ 针对已有去模糊方法存在的两个局限性:(1) 空间不变卷积核,对于动态场景去模糊而言并不是最优方案,严重限制了去模糊精度;(2) 经过网络深度与卷积核尺寸提高扩大感觉野,这会致使模型变大、推理耗时增长。框架

​ 为此,基于形变卷积与自注意力机制,做者提出一种高效果的端到端的去模糊框架。它与其余SOTA方法的性能对比见下图。ide

​ 该方法的优势包含如下几点:函数

  • 全卷积且参数高效,仅需一次性前向过程;
  • 可轻易集成其余架构与损失函数;
  • 网络估计的变换是动态的,于是能够自适应处理测试图像。

Method

​ 上图给出了做者所提出的SARN网络架构示意图,其中编码子网络将输入图像逐渐变换为分辨率更小、通道更多的特征图,在此基础上继续执行空间注意力模块与形变残差模块,最后送入到解码模块中,经过一系列的残差模块与反卷积对其进行重建。注:上图中n=32。性能

Deformable Residual Module

​ 传统的CNN在固定的网格上进行采样,这限制了其模拟未知几何变换的能力。STN将空间学习引入到CNN中,然而这种变换比较耗时且为全局图像变换,并不适合于局部图像几何变换。做者采用形变卷积,它以一种有效的方法学习局部几何变换。形变卷积首先学习稠密偏移图进行特征重采样,而后再进行卷积操做,该过程见上图中的从Input FeatureOutput Feature的过程。做者在形变卷积基础上引入参考模块,称之为形变残差模块。更多关于形变卷积的介绍与分析建议参考原文Deformable Convolutional Networks学习

Self-Attention Module

​ 近期的去模糊方法着重于多尺度处理,这种处理方式能够获取不一样尺度的运动模糊,提高网络的感觉野。尽管这种“自粗而精”的处理策略能够处理不一样程度的模糊,可是它没法从全局角度利用模糊区域之间的相关性,而这对于复原任务也很重要。为此,做者提出采用:在不一样空间分辨率利用注意力机制学习非局部关联性。测试

​ 用于模拟长范围依赖关系的注意力机制已在多个领域(跨语言与视觉应用)取得了成功。做者采用非局部注意力进行不一样场景区域之间的关联性学习并用于提高图像复原质量。优化

​ 上图给出了做者所提出的SAM模块示意图。它有以下两点优点:

  • 它克服了感觉野有限的局限性;
  • 它隐含的提供了一种能够传播相对信息的通路。

上述优点使得它适合于处理去模糊,这是由于:因模糊致使的场景-边缘之间每每是相关的。

​ 以上图为例,给定输入特征A\in R^{C \times H \times W},首先,将其送入两个1\times1卷积获得两个新的特征B和C,其中\{B, C\} \in R^{\hat{C} \times H \times W};而后,将其进行reshape为R^{\hat{C} \times N};其次,对B和C进行矩阵乘操做并执行softmax获得空间注意力特征S \in R^{N \times N}(s_{ji}能够度量i位置与j位置的影响关系),计算方式以下公式所示。最后,将特征A经由另外一个1\times1卷积获得特征D\in R^{C\times H \times W},并reshape为R^{C \times N},并将其S进行矩阵乘操做获得加强版特征,将其与特征A相加获得最终的特征E\in R^{C\times H \times W}

s_{ji} = \frac{\mathcal{exp}(B_i \cdot C_j)}{\sum_{i=1}^N \mathcal{exp}(B_i \cdot C_j)} \notag

​ 经由上述操做获得的特征E包含全部位置特征的加权组合以及原始特征。所以它具备全局上下文信息,并按照空间注意力进行上下文信息选择性集成,促使类似特征加强,不相关特征削弱。

​ 做者还发现:将SAM至于DRM以前能够取得更好的性能。猜想缘由为:早期的特征加强有助于提高网络的非局部性。

Video Deblurring

​ 图像去模糊一种很天然的扩展是视频去模糊,做者采用LSTM进行先后帧特征集成,该过程能够描述为:

\begin{split}
f^i &= Net_E(B^i, I^{i-1})  \\
h^i, g^i &= ConvLSTM(h^{i-1}, f^i; \theta_{LSTM})  \\
I^i &= Net_D(g^i; \theta_D)
\end{split}

在视频去模糊中,它以5帧做为输入,输出中间帧的去模糊效果图。

Experiments

​ 在训练过程当中,相关参数配置以下:

  • 对于图像去模糊任务,训练数据为GoPro,优化器为Adam,学习率为0.0001,BatchSize=4,训练迭代次数为1百万.
  • 对于视频去模糊任务,优化器Adam,学习率0.0001,BatchSize=4,迭代次数3百万。

​ 下图给出在GoPro数据集上相关去模糊方法的性能与视觉效果对比。更多实验结果与分析建议参考原文,这里再也不赘述。

​ 下面给出了在视频去模糊任务上的性能与视觉效果对比。更多实验结果与分析建议参考原文,这里再也不赘述。

Concolusion

​ 做者结合形变卷积、自注意力机制提出一种有效的图像/视频去模糊方法。其中形变卷积残差模块能够解决局部模糊的局部信息偏移问题;而自注意力机制则能够对不一样模糊区域创建关联性,从而提高特征性能。自注意力机制与形变卷积都可提高网络的感觉野,同时具备高效性。最后做者经过实验验证了所提方法的SOTA性能。

参考代码

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.ops import DeformConvPack

# GPU 
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# DeformConv copy from mmdetection.
class DeformResModule(nn.Module):
    def __init__(self, inc, ksize):
        super(DeformResModule, self).__init__()
        pad = (ksize-1)//2
        self.dconv = DeformConvPack(inc,inc,ksize,1,padding=pad)
    def forward(self, x):
        res = self.dconv(x)
        return res + x

class ResBlock(nn.Module):
    def __init__(self, inc, ksize):
        super(ResBlock, self).__init__()
        padding = (ksize-1)//2
        self.conv1 = nn.Conv2d(inc, inc, ksize, 1, padding)
        self.conv2 = nn.Conv2d(inc, inc, ksize, 1, padding)
    def forward(self, x):
        res = self.conv2(F.relu(self.conv1(x)))
        return res + x
        
class SAM(nn.Module):
    def __init__(self, inc):
        super(SAM, self).__init__()
        self.convb = nn.Conv2d(inc, inc, 1)
        self.convc = nn.Conv2d(inc, inc, 1)
        self.convd = nn.Conv2d(inc, inc, 1)
        
    def forward(self, x):
        N, C, H, W = x.size()
        featB = self.convb(x)                         #N,C,H,W
        featC = self.convc(x)                         #N,C,H,W
        featD = self.convd(x)                         #N,C,H,W
        
        featB = featB.reshape(N, C, -1)               #N,C, HW
        featC = featC.reshape(N, C, -1)               #N,C, HW
        featC = featC.permute(0, 2, 1)                #N,HW,C
        
        featD = featD.reshape(N, C, -1)               #N,C, HW
        featD = featD.permute(0, 2, 1)                #N,HW,C
        
        featBC = torch.matmul(featC, featB)           #N,HW,HW
        featBC = featBC.softmax(-1)                   #N,HW,HW
        
        fusion = torch.matmul(featBC, featD)          #N,HW,C
        fusion = fusion.permute(0, 2, 1).contiguous() #N,C, HW
        fusion = fusion.reshape(N, C, H, W)           #N,C,H,W
        
        return x + fusion

class Net(nn.Module):
    def __init__(self, inc, outc, midc):
        super(Net, self).__init__()
        mid2 = midc*2
        mid4 = midc*4
        self.ecode1 = nn.Sequential(nn.Conv2d(inc,midc,3,1,1),
                                    nn.ReLU(),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3))
        self.ecode2 = nn.Sequential(nn.Conv2d(midc,mid2,3,2,1),
                                    nn.ReLU(),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.ecode3 = nn.Sequential(nn.Conv2d(mid2,mid4,3,2,1),
                                    nn.ReLU(),
                                    SAM(mid4),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3))
        
        self.dcode2 = nn.Sequential(ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.dcode1 = nn.Sequential(ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    nn.Conv2d(midc, 3, 3, 1, 1))
                
        self.upsample1 = nn.ConvTranspose2d(mid4, mid2, 4, 2, 1)
        self.upsample2 = nn.ConvTranspose2d(mid2, midc, 4, 2, 1)
        
        self.feat1 = nn.Conv2d(midc, midc, 3, 1, 1)
        self.feat2 = nn.Conv2d(midc*2, midc*2, 3, 1, 1)
        
    def forward(self, x):
        encoder1 = self.ecode1(x)
        encoder2 = self.ecode2(encoder1)
        encoder3 = self.ecode3(encoder2)
        decoder3 = self.upsample1(encoder3)
        decoder2 = self.dcode2(decoder3 + self.feat2(encoder2))
        decoder1 = self.upsample2(decoder2)
        output   = self.dcode1(decoder1 + self.feat1(encoder1))
        
        return output
             
        
def main():
    model = Net(3, 3, 32).cuda().eval()
    
    inputs = torch.randn(4, 3, 128, 128).cuda()
    with torch.no_grad():
        output = model(inputs)
    print(output.size())
    
    
if __name__ == "__main__":
    main()
复制代码
相关文章
相关标签/搜索