In this post I introduce U-net, a classic semantic segmentation network. It was proposed in 2015 and first applied to medical image segmentation; because it worked so well, it has since been widely adopted for all kinds of segmentation tasks, and many segmentation models have been derived from it.
U-net is a typical encoder-decoder architecture: the encoder extracts features and the decoder upsamples them back to the input resolution. Because training data was limited, U-net relied on heavy data augmentation during training and still achieved good results.
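For segmentation, the same geometric transform has to be applied to the image and its mask so the pixel-wise labels stay aligned. The snippet below is a minimal augmentation sketch, assuming torchvision is available; the specific transforms (random flips, small rotations) are only an illustration, not the paper's exact recipe, which also relied on elastic deformations.

import random
import torchvision.transforms.functional as TF

def augment(image, mask):
    # Apply the identical random flip / rotation to both the image and the mask,
    # so the segmentation labels stay aligned with the pixels.
    if random.random() < 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)
    if random.random() < 0.5:
        image, mask = TF.vflip(image), TF.vflip(mask)
    angle = random.uniform(-10, 10)
    image = TF.rotate(image, angle)
    mask = TF.rotate(mask, angle)
    return image, mask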
The U-net architecture is shown below. The left half is the encoder, which downsamples the input; the downsampling is done with max pooling. The right half is the decoder, which upsamples the encoder output to restore the resolution; the upsampling is done with Upsample. In between are the skip connections (skip-connect), which fuse encoder and decoder features. Because the whole network is shaped like a "U", it is called U-net.
Except for the final output layer (a 1 × 1 convolution), every convolutional layer in the network is a 3 × 3 convolution.
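As a quick sanity check (an illustrative sketch, separate from the network code below): a 3 × 3 convolution with padding 1 preserves the spatial size, MaxPool2d halves it, and Upsample doubles it, which is exactly what produces the U shape.

import torch as t
import torch.nn as nn

x = t.randn(1, 64, 128, 128)
# 3x3 conv, stride 1, padding 1: spatial size unchanged
print(nn.Conv2d(64, 64, 3, 1, 1)(x).shape)      # torch.Size([1, 64, 128, 128])
# 2x2 max pooling: spatial size halved
print(nn.MaxPool2d(2, 2)(x).shape)              # torch.Size([1, 64, 64, 64])
# bilinear upsampling by a factor of 2: spatial size doubled
print(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)(x).shape)  # torch.Size([1, 64, 256, 256])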
import torch as t
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv + BatchNorm + ReLU blocks."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.dconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            # inplace=True saves GPU/CPU memory
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, img):
        return self.dconv(img)


# Downsampling block
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2, 2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, img):
        return self.down(img)


# Upsampling block
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # ConvTranspose2d has learnable parameters that are updated during training,
        # which increases model complexity and may cause overfitting.
        # Upsample has no learnable parameters.
        # The same distinction as between Conv2d and MaxPool2d.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pad x1 so that it has the same spatial size as x2
        dx = x2.shape[3] - x1.shape[3]
        dy = x2.shape[2] - x1.shape[2]
        x1 = nn.functional.pad(x1, [dx // 2, dx - dx // 2,
                                    dy // 2, dy - dy // 2])
        # concatenate along the channel dimension (skip connection)
        x = t.cat([x1, x2], dim=1)
        return self.conv(x)


# The full network
class CrackUnet(nn.Module):
    def __init__(self, channels, classes, bilinear=True):
        super(CrackUnet, self).__init__()
        self.channels = channels
        self.classes = classes
        self.bilinear = bilinear
        self.inconv = DoubleConv(self.channels, 64)
        # 4 downsampling blocks
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # 4 upsampling blocks, using bilinear upsampling by default
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # 1x1 conv maps the 64 feature channels to the number of classes
        self.outconv = nn.Conv2d(64, classes, 1)

    def forward(self, img):
        img = self.inconv(img)
        down1 = self.down1(img)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        x = self.up1(down4, down3)
        del down4
        del down3
        x = self.up2(x, down2)
        del down2
        x = self.up3(x, down1)
        del down1
        x = self.up4(x, img)
        del img
        return self.outconv(x)
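A minimal usage sketch follows; the batch size, channel count, class count, and image size are just examples, and any spatial size divisible by 16 works with the four pooling steps.

net = CrackUnet(channels=3, classes=2, bilinear=True)
x = t.randn(1, 3, 256, 256)   # one 3-channel 256x256 image
out = net(x)
print(out.shape)              # torch.Size([1, 2, 256, 256]): one score map per class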
U-net is simple and stable, a typical downsample-then-upsample segmentation architecture, and it is especially recommended when the dataset is small.