医学图像语义分割最佳方法的全面比较:UNet和UNet++

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”git


做者:Sergey Kolchenkoweb

编译:ronghuaiyang
微信

导读

在不一样的任务上对比了UNet和UNet++以及使用不一样的预训练编码器的效果。网络

介绍

语义分割是计算机视觉的一个问题,咱们的任务是使用图像做为输入,为图像中的每一个像素分配一个类。在语义分割的状况下,咱们不关心是否有同一个类的多个实例(对象),咱们只是用它们的类别来标记它们。有多种关于不一样计算机视觉问题的介绍课程,但用一张图片能够总结不一样的计算机视觉问题:架构

语义分割在生物医学图像分析中有着普遍的应用:x射线、MRI扫描、数字病理、显微镜、内窥镜等。https://grand-challenge.org/challenges上有许多不一样的有趣和重要的问题有待探索。
app

从技术角度来看,若是咱们考虑语义分割问题,对于N×M×3(假设咱们有一个RGB图像)的图像,咱们但愿生成对应的映射N×M×k(其中k是类的数量)。有不少架构能够解决这个问题,但在这里我想谈谈两个特定的架构,Unet和Unet++。dom

有许多关于Unet的评论,它如何永远地改变了这个领域。它是一个统一的很是清晰的架构,由一个编码器和一个解码器组成,前者生成图像的表示,后者使用该表示来构建分割。每一个空间分辨率的两个映射链接在一块儿(灰色箭头),所以能够将图像的两种不一样表示组合在一块儿。而且它成功了!编辑器

接下来是使用一个训练好的编码器。考虑图像分类的问题,咱们试图创建一个图像的特征表示,这样不一样的类在该特征空间能够被分开。咱们能够(几乎)使用任何CNN,并将其做为一个编码器,从编码器中获取特征,并将其提供给咱们的解码器。据我所知,Iglovikov & Shvets 使用了VGG11和resnet34分别为Unet解码器以生成更好的特征和提升其性能。
函数

TernausNet (VGG11 Unet)

Unet++是最近对Unet体系结构的改进,它有多个跳跃链接。性能

根据论文, Unet++的表现彷佛优于原来的Unet。就像在Unet中同样,这里可使用多个编码器(骨干)来为输入图像生成强特征。

我应该使用哪一个编码器?

这里我想重点介绍Unet和Unet++,并比较它们使用不一样的预训练编码器的性能。为此,我选择使用胸部x光数据集来分割肺部。这是一个二值分割,因此咱们应该给每一个像素分配一个类为“1”的几率,而后咱们能够二值化来制做一个掩码。首先,让咱们看看数据。

来自胸片X光数据集的标注数据的例子

这些是很是大的图像,一般是2000×2000像素,有很大的mask,从视觉上看,找到肺不是问题。使用segmentation_models_pytorch库,咱们为Unet和Unet++使用100+个不一样的预训练编码器。咱们作了一个快速的pipeline来训练模型,使用Catalyst (pytorch的另外一个库,这能够帮助你训练模型,而没必要编写不少无聊的代码)和Albumentations(帮助你应用不一样的图像转换)。

  1. 定义数据集和加强。咱们将调整图像大小为256×256,并对训练数据集应用一些大的加强。
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict

class ChestXRayDataset(Dataset):
    def __init__(
        self,
        images,
        masks,
            transforms)
:

        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self):
        return(len(self.images))

    def __getitem__(self, idx):
        """Will load the mask, get random coordinates around/with the mask,
        load the image by coordinates
        """

        sample_image = imread(self.images[idx])
        if len(sample_image.shape) == 3:
            sample_image = sample_image[..., 0]
        sample_image = np.expand_dims(sample_image, 2) / 255
        sample_mask = imread(self.masks[idx]) / 255
        if len(sample_mask.shape) == 3:
            sample_mask = sample_mask[..., 0]  
        augmented = self.transforms(image=sample_image, mask=sample_mask)
        sample_image = augmented['image']
        sample_mask = augmented['mask']
        sample_image = sample_image.transpose(201)  # channels first
        sample_mask = np.expand_dims(sample_mask, 0)
        data = {'features': torch.from_numpy(sample_image.copy()).float(),
                'mask': torch.from_numpy(sample_mask.copy()).float()}
        return(data)
    
def get_valid_transforms(crop_size=256):
    return A.Compose(
        [
            A.Resize(crop_size, crop_size),
        ],
        p=1.0)

def light_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
    ])

def medium_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])


def heavy_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])

def get_training_trasnforms(transforms_type):
    if transforms_type == 'light':
        return(light_training_transforms())
    elif transforms_type == 'medium':
        return(medium_training_transforms())
    elif transforms_type == 'heavy':
        return(heavy_training_transforms())
    else:
        raise NotImplementedError("Not implemented transformation configuration")
  1. 定义模型和损失函数。这里咱们使用带有regnety_004编码器的Unet++,并使用RAdam + Lookahed优化器使用DICE + BCE损失之和进行训练。
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn
from skimage.io import imread
import os
from sklearn.model_selection import train_test_split
from catalyst.dl import  CriterionCallback, MetricAggregationCallback
encoder = 'timm-regnety_004'
model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
#model.cuda()
learning_rate = 5e-3
encoder_learning_rate = 5e-3 / 10
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
criterion = {
    "dice": DiceLoss(mode='binary'),
    "bce": nn.BCEWithLogitsLoss()
}
  1. 定义回调函数并训练!
callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
       input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_bce",
        criterion_key="bce"
    ),

    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum"
        metrics={
            "loss_dice"1.0
            "loss_bce"0.8
        },
    ),

    # metrics
    IoUMetricsCallback(
        mode='binary'
        input_key='mask'
    )
    
]

runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir='../logs/xray_test_log',
    num_epochs=100,
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
)

若是咱们用不一样的编码器对Unet和Unet++进行验证,咱们能够看到每一个训练模型的验证质量,并总结以下:

Unet和Unet++验证集分数

咱们注意到的第一件事是,在全部编码器中,Unet++的性能彷佛都比Unet好。固然,有时这种差别并非很大,咱们不能说它们在统计上是否彻底不一样 —— 咱们须要在多个folds上训练,看看分数分布,单点不能证实任何事情。第二,resnest200e显示了最高的质量,同时仍然有合理的参数数量。有趣的是,若是咱们看看https://paperswithcode.com/task/semantic-segmentation,咱们会发现resnest200在一些基准测试中也是SOTA。

好的,可是让咱们用Unet++和Unet使用resnest200e编码器来比较不一样的预测。

Unet和Unet++使用resnest200e编码器的预测。左图显示了两种模型的预测差别

在某些个别状况下,Unet++实际上比Unet更糟糕。但总的来讲彷佛更好一些。

通常来讲,对于分割网络来讲,这个数据集看起来是一个容易的任务。让咱们在一个更难的任务上测试Unet++。为此,我使用PanNuke数据集,这是一个带标注的组织学数据集(205,343个标记核,19种不一样的组织类型,5个核类)。数据已经被分割成3个folds。

PanNuke样本的例子

咱们可使用相似的代码在这个数据集上训练Unet++模型,以下所示:

验证集上的Unet++得分

咱们在这里看到了相同的模式 - resnest200e编码器彷佛比其余的性能更好。咱们能够用两个不一样的模型(最好的是resnest200e编码器,最差的是regnety_002)来可视化一些例子。

resnest200e和regnety_002的预测

咱们能够确定地说,这个数据集是一项更难的任务 —— 不只mask不够精确,并且个别的核被分配到错误的类别。然而,使用resnest200e编码器的Unet++仍然表现很好。

总结

这不是一个全面语义分割的指导,这更多的是一个想法,使用什么来得到一个坚实的基线。有不少模型、FPN,DeepLabV3, Linknet与Unet有很大的不一样,有许多Unet-like架构,例如,使用双编码器的Unet,MAnet,PraNet,U²-net — 有不少的型号供你选择,其中一些可能在你的任务上表现的比较好,可是,一个坚实的基线能够帮助你从正确的方向上开始。


END

英文原文:https://towardsdatascience.com/the-best-approach-to-semantic-segmentation-of-biomedical-images-bbe4fd78733f

请长按或扫描二维码关注本公众号


喜欢的话,请给我个在看吧


本文分享自微信公众号 - AI公园(AI_Paradise)。
若有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一块儿分享。

相关文章
相关标签/搜索