crf(条件随机场)用于遥感影像分类结果的优化

参考链接:
1. https://blog.csdn.net/wzw12315/article/details/106475791
2. https://www.cnblogs.com/wanghui-garcia/p/10761612.html
主要的代码段都差不多,只是改用 GDAL 读入数据;结果还是有些变化的,参数需要自己慢慢调整。

""" Adapted from the inference.py to demonstrate the usage of the util functions. """
import sys
import numpy as np
import pydensecrf.densecrf as dcrf
import cv2
import gdal
from skimage import color
# Get im{read,write} from somewhere.
# try:
    # from cv2 import imread, imwrite
# except ImportError:
    # # Note that, sadly, skimage unconditionally import scipy and matplotlib,
    # # so you'll need them if you don't have OpenCV. But you probably have them.
    # from skimage.io import imread, imsave
    # imwrite = imsave
    # TODO: Use scipy instead.


from skimage.io import imread, imsave
imwrite = imsave
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

def read_img(filename):
    """Read a raster with GDAL and return its georeferencing and pixel data.

    Args:
        filename: Path of a raster file that GDAL can open.

    Returns:
        Tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster width and height in
        pixels, and the array ``ReadAsArray`` returns for the full extent.
        The caller in this script transposes the image array with
        ``(1, 2, 0)``, i.e. it expects multi-band data band-first.

    Raises:
        RuntimeError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None unless GDAL exceptions
        # are enabled; fail here with the path instead of with an
        # AttributeError on the next line.
        raise RuntimeError("GDAL could not open raster: {}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL can close the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a georeferenced GeoTIFF.

    Args:
        filename: Output path.
        im_proj: Projection, passed to ``SetProjection``.
        im_geotrans: Geotransform, passed to ``SetGeoTransform``.
        im_data: Either a 2-D ``(height, width)`` array or a 3-D
            ``(bands, height, width)`` array.

    Raises:
        RuntimeError: If GDAL cannot create the output dataset.
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Look the dtype up by its exact name.  A substring test such as
    # ``'int16' in name`` also matches 'uint16', so signed 16-bit data used to
    # be stored as unsigned, and wider types were downgraded to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # same choice as before for signed bytes
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    # Every other dtype (e.g. int64, float32) keeps the Float32 fallback.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Raster bands are numbered from 1.  Writing plane by plane also covers a
    # 3-D array with a single band, which was previously handed to WriteArray
    # as a 3-D array.
    for band_number, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_number).WriteArray(plane)

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset

def crf(inimage, img_anno, use_2d=False, n_iters=20, gt_prob=0.7):
    """Smooth a hard classification map with a fully connected CRF.

    Every distinct value in ``img_anno`` is treated as a class of its own;
    there is no "unknown"/background value.

    Args:
        inimage: Source image, channel-last ``(H, W, C)``.  The bilateral
            term is configured for three channels.
        img_anno: Predicted label map with ``H * W`` elements, e.g. holding
            the values 0, 1, 2, 3.
        use_2d: If True use ``DenseCRF2D`` with its built-in pairwise terms,
            otherwise the generic ``DenseCRF`` with the util feature helpers
            (the default, and the variant the original author recommends).
        n_iters: Number of inference iterations.
        gt_prob: Confidence assigned to the given label of each pixel; must
            lie strictly between 0 and 1.

    Returns:
        ``img_anno`` itself when it holds at most one class.  Otherwise an
        ``(H, W)`` array holding, for every pixel, the annotation value of
        its most probable class.
    """
    img = inimage
    anno = np.asarray(img_anno)
    print("=========>>", anno.shape)
    print("=======>>", anno.shape)
    print(np.max(anno), np.min(anno))

    # `labels` indexes into the sorted `class_values`.  No bit masking is done
    # on the values any more: it was unused and raised TypeError for
    # non-integer label rasters.
    class_values, labels = np.unique(anno, return_inverse=True)
    labels = labels.ravel()
    n_labels = len(class_values)
    if n_labels <= 1:
        # A single class leaves nothing to refine.
        return img_anno

    height, width = img.shape[:2]

    # Labels are passed 1-based with zero_unsure=True: that mode reserves 0
    # for "unsure" and puts label k into row k - 1, so row i of U belongs to
    # class_values[i].
    U = unary_from_labels(labels + 1, n_labels, gt_prob=gt_prob,
                          zero_unsure=True)

    if use_2d:
        d = dcrf.DenseCRF2D(width, height, n_labels)
        d.setUnaryEnergy(U)
        # Kernels: CONST_KERNEL, DIAG_KERNEL (default), FULL_KERNEL.
        # Normalizations: NO_NORMALIZATION, NORMALIZE_BEFORE, NORMALIZE_AFTER,
        # NORMALIZE_SYMMETRIC (default).
        d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)
        # The bilateral term is given a C-ordered uint8 copy of the image.
        rgb_u8 = np.array(img.astype(int), dtype=np.uint8, order='C')
        d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=rgb_u8,
                               compat=10,
                               kernel=dcrf.CONST_KERNEL,
                               normalization=dcrf.NORMALIZE_SYMMETRIC)
    else:
        d = dcrf.DenseCRF(width * height, n_labels)
        d.setUnaryEnergy(U)

        # Colour-independent term: features are pixel positions only.
        feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

        # Colour-dependent term: positions plus the channel values.
        feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

    Q = d.inference(n_iters)

    # Most probable class per pixel, mapped back to the annotation's own
    # values so that non-contiguous label sets (e.g. 1, 2, 4) survive.
    MAP = np.argmax(Q, axis=0)
    return class_values[MAP].reshape(height, width)

if __name__ == "__main__":
    # Source image, predicted label raster, and where the result is written.
    img_path = "D:/xx/xx/xx.tif"
    anno = 'D:/xx/result/t.tif'
    out = 'D:/xx/result/t_t.tif'

    # Only the pixel array of the source image is needed; move its first
    # axis to the end because crf() indexes the image channel-last.
    source = read_img(img_path)
    img = source[4].transpose(1, 2, 0)

    # The output reuses the projection and geotransform of the label raster.
    im_proj, im_geotrans, _, _, lab = read_img(anno)

    arr = crf(img, lab)
    write_img(out, im_proj, im_geotrans, arr)

原图
处理前的结果:
预测结果
处理后的结果:
仔细看变化还是挺大的,去掉了不少杂质,让类别分布更纯粹。
(此处原为图片)

其他版本:
1.

import os
import gdal
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import compute_unary, create_pairwise_bilateral,create_pairwise_gaussian, softmax_to_unary, unary_from_softmax,unary_from_labels

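# Getting a unary
# There are two common ways to obtain the unary potentials:
# 1) From hard labels produced by a human or by another process; this is
#    implemented by pydensecrf.utils.unary_from_labels.
# 2) From a probability distribution, e.g. the softmax output of a deep
#    network: the image is first classified with a trained network, and the
#    resulting softmax scores then have to be converted into unary potentials.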

def read_img(filename):
    """Read a raster with GDAL and return its georeferencing and pixel data.

    Args:
        filename: Path of a raster file that GDAL can open.

    Returns:
        Tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster width and height in
        pixels, and the array ``ReadAsArray`` returns for the full extent.
        The caller in this script transposes the image array with
        ``(1, 2, 0)``, i.e. it expects multi-band data band-first.

    Raises:
        RuntimeError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None unless GDAL exceptions
        # are enabled; fail here with the path instead of with an
        # AttributeError on the next line.
        raise RuntimeError("GDAL could not open raster: {}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL can close the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a georeferenced GeoTIFF.

    Args:
        filename: Output path.
        im_proj: Projection, passed to ``SetProjection``.
        im_geotrans: Geotransform, passed to ``SetGeoTransform``.
        im_data: Either a 2-D ``(height, width)`` array or a 3-D
            ``(bands, height, width)`` array.

    Raises:
        RuntimeError: If GDAL cannot create the output dataset.
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Look the dtype up by its exact name.  A substring test such as
    # ``'int16' in name`` also matches 'uint16', so signed 16-bit data used to
    # be stored as unsigned, and wider types were downgraded to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # same choice as before for signed bytes
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    # Every other dtype (e.g. int64, float32) keeps the Float32 fallback.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Raster bands are numbered from 1.  Writing plane by plane also covers a
    # 3-D array with a single band, which was previously handed to WriteArray
    # as a 3-D array.
    for band_number, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_number).WriteArray(plane)

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset


def dense_crf(img, pre, save, im_proj, im_geotrans, n_labels=None, n_iters=5):
    """Refine a prediction with a dense CRF and write it out as a GeoTIFF.

    The previous implementation copied ``pre`` into every class plane of the
    unary term, so all classes received identical energies and the unary
    carried no class information.  Taking ``-log`` of label values also
    produced infinite (label 0) or negative (labels above 1) energies.  The
    unary is now built per class.

    Args:
        img: Image array, channel-last ``(H, W, C)``; the bilateral term is
            configured for three channels.
        pre: Either class probabilities of shape ``(n_labels, H, W)`` (e.g. a
            softmax output) or a hard label map of shape ``(H, W)`` whose
            values are class ids in ``[0, n_labels)``.
        save: Output path, passed to ``write_img``.
        im_proj: Projection, passed to ``write_img``.
        im_geotrans: Geotransform, passed to ``write_img``.
        n_labels: Number of classes.  Defaults to ``pre.shape[0]`` for
            probabilities and to 4 (the value that used to be hard-coded)
            for label maps.
        n_iters: Number of inference iterations.

    Returns:
        ``(H, W)`` integer array: the most probable class index of every
        pixel multiplied by 255, as before.

    Raises:
        ValueError: If the shape of ``pre`` does not match ``img`` or a label
            lies outside ``[0, n_labels)``.
    """
    height, width = img.shape[:2]
    pre = np.asarray(pre)

    if pre.ndim == 3:
        if n_labels is None:
            n_labels = pre.shape[0]
        if pre.shape != (n_labels, height, width):
            raise ValueError(
                "expected probabilities of shape {}, got {}".format(
                    (n_labels, height, width), pre.shape))
        # Negative log-probabilities; clip so that a probability of exactly
        # zero does not turn into an infinite energy.
        probs = np.clip(pre.astype(np.float32), 1e-8, 1.0)
        unary = -np.log(probs).reshape((n_labels, -1))
    elif pre.ndim == 2:
        if n_labels is None:
            n_labels = 4
        if pre.shape != (height, width):
            raise ValueError(
                "expected a label map of shape {}, got {}".format(
                    (height, width), pre.shape))
        labels = pre.astype(np.int64).ravel()
        if labels.min() < 0 or labels.max() >= n_labels:
            raise ValueError(
                "label values must lie in [0, {})".format(n_labels))
        # Labels are passed 1-based with zero_unsure=True: that mode reserves
        # 0 for "unsure" and puts label k into row k - 1, so row i of the
        # unary belongs to class id i.  gt_prob=0.7 is the confidence used in
        # the commented-out unary_from_labels experiment this replaces.
        unary = unary_from_labels(labels + 1, n_labels, gt_prob=0.7,
                                  zero_unsure=True)
    else:
        raise ValueError(
            "`pre` must be 2-D or 3-D, got {} dimensions".format(pre.ndim))

    # The Cython wrapper needs a C-contiguous float32 array.
    unary = np.ascontiguousarray(unary, dtype=np.float32)

    d = dcrf.DenseCRF2D(width, height, n_labels)
    d.setUnaryEnergy(unary)

    # Position-only term: penalises small, spatially isolated regions.
    feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
    d.addPairwiseEnergy(feats, compat=3,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Colour-dependent term: uses local colour to refine coarse predictions.
    feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                      img=img, chdim=2)
    d.addPairwiseEnergy(feats, compat=10,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # The original author measured a 2592x1456 image at about 17 s for
    # 5 iterations and about 38 s for 20.
    Q = d.inference(n_iters)
    res = np.argmax(Q, axis=0).reshape((height, width))
    # NOTE(review): the x255 scaling only yields a displayable mask for two
    # classes; it is kept so existing outputs do not change scale.
    res = res * 255

    write_img(save, im_proj, im_geotrans, res)

    return res


if __name__ == "__main__":
    # Source image, prediction raster, and where the result is written.
    img_path = "D:/xx/xx/xx.tif"
    anno = 'D:/xx/result/t.tif'
    out = 'D:/xx/result/t_t.tif'

    # Only the pixel array of the source image is needed; move its first
    # axis to the end because dense_crf() indexes the image channel-last.
    source = read_img(img_path)
    img = source[4].transpose(1, 2, 0)

    # The output reuses the projection and geotransform of the prediction.
    im_proj, im_geotrans, _, _, lab = read_img(anno)

    dense_crf(img, lab, out, im_proj, im_geotrans)
import os, sys
import numpy as np
import pydensecrf.densecrf as dcrf
import cv2, gdal
from collections import Counter

# Get im{read,write} from somewhere.
try:
    from cv2 import imread, imwrite
except ImportError:
    # Note that, sadly, skimage unconditionally import scipy and matplotlib,
    # so you'll need them if you don't have OpenCV. But you probably have them.
    from skimage.io import imread, imsave
    imwrite = imsave
    # TODO: Use scipy instead.

from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

def read_img(filename):
    """Read a raster with GDAL and return its georeferencing and pixel data.

    Args:
        filename: Path of a raster file that GDAL can open.

    Returns:
        Tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster width and height in
        pixels, and the array ``ReadAsArray`` returns for the full extent.
        The caller in this script transposes the image array with
        ``(1, 2, 0)``, i.e. it expects multi-band data band-first.

    Raises:
        RuntimeError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None unless GDAL exceptions
        # are enabled; fail here with the path instead of with an
        # AttributeError on the next line.
        raise RuntimeError("GDAL could not open raster: {}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL can close the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a georeferenced GeoTIFF.

    Args:
        filename: Output path.
        im_proj: Projection, passed to ``SetProjection``.
        im_geotrans: Geotransform, passed to ``SetGeoTransform``.
        im_data: Either a 2-D ``(height, width)`` array or a 3-D
            ``(bands, height, width)`` array.

    Raises:
        RuntimeError: If GDAL cannot create the output dataset.
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Look the dtype up by its exact name.  A substring test such as
    # ``'int16' in name`` also matches 'uint16', so signed 16-bit data used to
    # be stored as unsigned, and wider types were downgraded to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # same choice as before for signed bytes
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    # Every other dtype (e.g. int64, float32) keeps the Float32 fallback.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Raster bands are numbered from 1.  Writing plane by plane also covers a
    # 3-D array with a single band, which was previously handed to WriteArray
    # as a 3-D array.
    for band_number, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_number).WriteArray(plane)

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset


def crf(x, y, z, use_2d=False, n_iters=5, gt_prob=0.7):
    """Refine an annotation image with a dense CRF and save the result.

    Every distinct annotation value (or colour) is treated as a class of its
    own, including 0.  The earlier workaround rewrote 0 to 4 in both the
    annotation and the label indices; that could merge two classes and broke
    the correspondence between labels, unary rows and output colours.

    Args:
        x: Path of the source image, read with ``read_img``; its array is
            transposed to channel-last.  The bilateral term is configured
            for three channels.
        y: Path of the annotation image, read with ``imread``.  Either a 2-D
            label image or an image with at least three channels, which are
            packed into one integer per pixel.
        z: Output path, written with ``imwrite``.
        use_2d: If True use ``DenseCRF2D`` with its built-in pairwise terms,
            otherwise the generic ``DenseCRF`` with the util feature helpers.
        n_iters: Number of inference iterations.
        gt_prob: Confidence assigned to the annotated label of each pixel;
            must lie strictly between 0 and 1.

    Returns:
        None.  The refined map is written to ``z`` as a single-channel image
        holding, per pixel, channel 0 of the colour of its most probable
        class (the label value itself for a 2-D annotation).

    Raises:
        IOError: If the annotation cannot be read or the output not written.
        ValueError: If the annotation and the image differ in height/width.
    """
    fn_im, fn_anno, fn_output = x, y, z

    ##################################
    ### Read images and annotation ###
    ##################################
    # Only the pixel array of the source image is used here.
    im_data = read_img(fn_im)[4]
    img = im_data.transpose(1, 2, 0)
    height, width = img.shape[:2]

    anno = imread(fn_anno)
    if anno is None:
        # cv2.imread signals an unreadable file by returning None.
        raise IOError("Could not read annotation image: {}".format(fn_anno))
    anno = np.asarray(anno)

    if anno.ndim == 2:
        anno_lbl = anno
    else:
        # Pack the first three channels into one 32-bit integer per pixel.
        channels = anno.astype(np.uint32)
        anno_lbl = (channels[:, :, 0]
                    + (channels[:, :, 1] << 8)
                    + (channels[:, :, 2] << 16))

    if anno_lbl.shape != (height, width):
        raise ValueError(
            "annotation shape {} does not match image shape {}".format(
                anno_lbl.shape, (height, width)))

    # `labels` indexes into the sorted `colors`.
    colors, labels = np.unique(anno_lbl, return_inverse=True)
    labels = labels.ravel()
    n_labels = len(colors)
    print(n_labels, " labels", set(labels.flat))

    # Value written for each class: channel 0 of its colour, or the label
    # value itself when the annotation has no channel axis.
    if anno.ndim == 2:
        first_channel = colors
    else:
        first_channel = (colors & 0x0000FF).astype(np.uint8)

    ###########################
    ### Setup the CRF model ###
    ###########################
    if n_labels < 2:
        # A single class leaves nothing to refine; write it back unchanged.
        d = None
        MAP = labels
    else:
        # Labels are passed 1-based with zero_unsure=True: that mode reserves
        # 0 for "unsure" and puts label k into row k - 1, so row i of U
        # belongs to colors[i].
        U = unary_from_labels(labels + 1, n_labels, gt_prob=gt_prob,
                              zero_unsure=True)

        if use_2d:
            print("Using 2D specialized functions")

            d = dcrf.DenseCRF2D(width, height, n_labels)
            d.setUnaryEnergy(U)

            # Colour-independent term: features are the locations only.
            d.addPairwiseGaussian(sxy=(3, 3), compat=3,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            # Colour-dependent term.  The transposed GDAL array is not
            # contiguous, so a C-ordered uint8 copy is passed instead.
            # NOTE(review): values outside 0-255 do not survive this cast.
            rgb_u8 = np.ascontiguousarray(img, dtype=np.uint8)
            d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13),
                                   rgbim=rgb_u8,
                                   compat=10,
                                   kernel=dcrf.DIAG_KERNEL,
                                   normalization=dcrf.NORMALIZE_SYMMETRIC)
        else:
            print("Using generic 2D functions")

            d = dcrf.DenseCRF(width * height, n_labels)
            d.setUnaryEnergy(U)

            # Colour-independent features.
            feats = create_pairwise_gaussian(sdims=(3, 3),
                                             shape=img.shape[:2])
            d.addPairwiseEnergy(feats, compat=3,
                                kernel=dcrf.DIAG_KERNEL,
                                normalization=dcrf.NORMALIZE_SYMMETRIC)

            # Colour-dependent features.
            feats = create_pairwise_bilateral(sdims=(80, 80),
                                              schan=(13, 13, 13),
                                              img=img, chdim=2)
            d.addPairwiseEnergy(feats, compat=10,
                                kernel=dcrf.DIAG_KERNEL,
                                normalization=dcrf.NORMALIZE_SYMMETRIC)

        ####################################
        ### Do inference and compute MAP ###
        ####################################
        Q = d.inference(n_iters)

        # Most probable class for each pixel.
        MAP = np.argmax(Q, axis=0)

    # Map the class indices back to output values and save the image.
    re_out = first_channel[MAP].reshape(height, width)
    if imwrite(fn_output, re_out) is False:
        # cv2.imwrite reports failure by returning False.
        raise IOError("Could not write output image: {}".format(fn_output))

    if d is None:
        return

    # Diagnostic only: step the inference manually and print the
    # KL-divergence before each step.  The step call used to sit outside the
    # loop, so the same value was printed five times.
    Q, tmp1, tmp2 = d.startInference()
    for i in range(5):
        print("KL-divergence at {}: {}".format(i, d.klDivergence(Q)))
        d.stepInference(Q, tmp1, tmp2)

    print(np.shape(Q), np.shape(MAP), np.shape(tmp2))


if __name__ == "__main__":
    # Source image, annotation raster, and where the result is written.
    img_path, anno, out = (
        "D:/xx/xx.tif",
        'D:/xx/temp/class_raster.tif',
        'D:/xx/temp/class_raster_crf.tif',
    )
    crf(img_path, anno, out)

    # Batch variant, left disabled: run crf() on every file of a folder,
    # pairing image, prediction and output by file name.
    # img_path = ''
    # pre_path = ''
    # out_path = ''
    # img_names = os.listdir(img_path)
    # for name in img_names:
    #     im_full_path = os.path.join(img_path, name)
    #     pre_full_path = os.path.join(pre_path, name)
    #     out_full_path = os.path.join(out_path, name)
    #     crf(im_full_path, pre_full_path, out_full_path)