PaddlePaddle PaddleColorization: Colorizing Black-and-White Photos


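Isn't colorizing a black-and-white photo a little bit magical?

This project walks you step by step through colorizing black-and-white images, and even black-and-white films.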

Download and installation commands

## Install command for the CPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## Install command for the GPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

 

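Colorizing black-and-white photos

Many classic old photographs survive only in black and white because of the technology of their time. Black-and-white photos have a charm of their own, but color photos can sometimes feel more immersive. This project implements black-and-white photo colorization in a simple, easy-to-follow way and gets good results on some photos. Colorization is a classic problem in computer vision. In recent years, with the widespread use of convolutional neural networks (CNNs), colorizing black-and-white photos with a CNN has become a novel and practical approach. The project is hosted on AI Studio, Baidu's learning and training community. It uses a ResNet residual network as the backbone and trains it with a composite loss function.

Let's start the colorization journey!

First, a look at the finished results.

Feel free to fork the project and learn from it. If you have any questions, leave a comment so we can discuss them.

A small plug: my interests include transfer learning and generative adversarial networks. Feel free to get in touch and follow me on AI Studio; I'll follow back.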

# Install the required dependencies (the PyPI package for sklearn is named scikit-learn)
!pip install scikit-learn scikit-image

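1 Project overview

This project is based on PaddlePaddle. Using a residual network (ResNet) and supervised learning, it trains a model that converts black-and-white images to color images.


1.1 Residual network (ResNet)

1.1.1 Background

ResNet (Residual Network) [15] won the 2015 ImageNet competitions in image classification, object localization, and object detection. To address the problem that accuracy degrades as networks get deeper, ResNet introduced residual learning to make deep networks easier to train. Building on existing design ideas (BN, small convolution kernels, fully convolutional networks), it adds the residual block. Each residual block has two paths: one is a direct shortcut for the input features, and the other applies two or three convolutions to those features to obtain their residual. The features from the two paths are then added together.

The residual blocks are shown in Figure 9. On the left is the basic block, made up of two 3x3 convolutions with the same number of output channels. On the right is the bottleneck block. It is called a bottleneck because the upper 1x1 convolution reduces the dimensionality (256->64 in the figure) and the lower 1x1 convolution restores it (64->256 in the figure), so the 3x3 convolution in the middle has a small number of input and output channels (64->64 in the figure).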

[Figure 9: residual block structures, basic block (left) and bottleneck block (right). Image: https://ai-studio-static-online.cdn.bcebos.com/7ede3132804549228b5c4a729d90e6b25821272dd9e74e41a95d3363f9e06c0e]
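To make the bottleneck argument concrete, here is a small sketch that counts convolution weights for the two block types. It assumes 256 input and output channels for both blocks (the channel counts from the figure example) and ignores biases and BatchNorm parameters. The helper names are made up for this illustration and are not part of the project.

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a k x k convolution (biases ignored)."""
    return c_in * c_out * k * k

def basic_block_weights(channels):
    # Basic block: two 3x3 convolutions with the same number of channels
    return 2 * conv_weights(channels, channels, 3)

def bottleneck_weights(channels, reduced):
    # Bottleneck: 1x1 reduce -> 3x3 at the reduced width -> 1x1 expand
    return (conv_weights(channels, reduced, 1)
            + conv_weights(reduced, reduced, 3)
            + conv_weights(reduced, channels, 1))

basic = basic_block_weights(256)
bottleneck = bottleneck_weights(256, 64)
print(basic, bottleneck)  # 1179648 69632
```

Under these assumptions the bottleneck block needs roughly 17 times fewer convolution weights at the same outer width, which is why it is preferred in very deep ResNets.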

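1.2 Design approach and the main problem addressed

  • Design approach: by learning from a large number of samples, the network acquires an empirical distribution (for example, the sky is blue and grass is green) and uses it to infer plausible colors for each part of a black-and-white image.
  • Main problem addressed: many objects do not have a fixed color, that is, object color is multimodal (for example, an apple can be red, green, or yellow). With mean squared error as the loss function, objects with multimodal colors drift toward an "average" color (usually a pale yellow), so the colorized image has low saturation.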
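The averaging effect described above can be checked with a few numbers. The sketch below is only an illustration: the `a` channel values (+60 for red apples, -60 for green ones) are invented, and the model is reduced to a single constant prediction.

```python
def best_constant_under_mse(values):
    """The constant that minimizes mean squared error is the arithmetic mean."""
    return sum(values) / len(values)

def mse(prediction, values):
    return sum((prediction - v) ** 2 for v in values) / len(values)

# Hypothetical 'a' channel values (green-red axis of Lab) for apples in a
# training set: half are red (a = +60), half are green (a = -60).
apple_a_values = [60.0, 60.0, -60.0, -60.0]

avg = best_constant_under_mse(apple_a_values)
print(avg)                        # 0.0: neither red nor green
print(mse(avg, apple_a_values))   # 3600.0
print(mse(60.0, apple_a_values))  # 7200.0: always answering "red" costs more
```

The desaturated middle value has the lowest error even though no apple in the data actually has that color, which is the behavior a plain mean squared error loss rewards.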

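1.3 Main features of this article

  • The Adam optimizer's beta1 parameter is set to 0.8; see the original paper for details
  • The momentum parameter of BatchNorm (batch normalization) is set to 0.5
  • The basic block is used as the residual block
  • To suppress the multimodality problem, the loss function is redesigned on top of mean squared error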

The loss function is:

$$Out = \frac{1}{n}\sum{(input - label)^{2}} + \frac{16.7}{n\sum{(input - \overline{input})^{2}}}$$
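As a rough illustration of how the two terms of this formula behave, here is a plain-Python sketch that evaluates it on short lists of numbers. It follows the formula exactly as written; the function name and the sample values are assumptions made for the example, and the project's own `createLoss` works on Paddle tensors instead.

```python
def colorization_loss(pred, label, weight=16.7):
    """Out = 1/n * sum((pred - label)^2) + weight / (n * sum((pred - mean(pred))^2)).

    Undefined (division by zero) when every prediction is identical.
    """
    n = len(pred)
    mse = sum((p - t) ** 2 for p, t in zip(pred, label)) / n
    mean_pred = sum(pred) / n
    spread = sum((p - mean_pred) ** 2 for p in pred)
    return mse + weight / (n * spread)

label = [40.0, -40.0, 40.0, -40.0]
flat = [1.0, -1.0, 1.0, -1.0]       # washed-out prediction, little spread
vivid = [40.0, -40.0, 40.0, -40.0]  # saturated prediction that matches the label

print(colorization_loss(flat, label))
print(colorization_loss(vivid, label))
```

The second term grows as the predicted values collapse toward their own mean, so a nearly constant, low-saturation output is penalized in addition to its squared error.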


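1.4 Dataset (ImageNet)

ImageNet is a large visual database built for research on visual object recognition software. More than 14 million image URLs have been hand-annotated by ImageNet to indicate the objects in the pictures, and bounding boxes are provided for at least one million of the images. ImageNet contains more than 20,000 categories; [2] a typical category, such as "balloon" or "strawberry", contains several hundred images. The annotation database of third-party image URLs is freely available directly from ImageNet, but the actual images are not owned by ImageNet. Since 2010, the ImageNet project has run an annual software contest, the ImageNet Large Scale Visual Recognition Challenge (ILSVRC), in which programs compete to correctly classify and detect objects and scenes. The challenge uses a "trimmed" list of 1,000 non-overlapping classes. The 2012 breakthrough on the ImageNet challenge is widely regarded as the start of the deep learning revolution of the 2010s. (Source: Baidu Baike)

About ImageNet2012:

  • Training images (Task 1 & 2). 138GB. (about 1.2 million high-resolution images in 1,000 classes)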
  • Validation images (all tasks). 6.3GB.
  • Training bounding box annotations (Task 1 & 2 only). 20MB.

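1.5 Lab color space

The Lab mode is based on an international standard for measuring color defined by the Commission Internationale de l'Éclairage (CIE) in 1931, which was revised in 1976 and given its name. The Lab color model makes up for the shortcomings of the RGB and CMYK color modes. It is a device-independent color model and is also based on physiological characteristics. [1] The Lab model has three components: lightness (L) and two color channels, a and b. The a channel runs from dark green (low values) through gray (middle values) to bright pink (high values); the b channel runs from bright blue (low values) through gray (middle values) to yellow (high values). Mixing these colors therefore produces bright, vivid colors. (Source: Baidu Baike)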
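To get a feel for the three channels, the sketch below converts a single sRGB color to Lab by hand. It is a minimal version of the standard sRGB to CIELAB conversion with a D65 white point, written for illustration only; the project itself relies on `skimage.color.rgb2lab`, and the function name here is an assumption.

```python
def srgb_to_lab(r, g, b):
    """Convert one sRGB color (components in [0, 1]) to CIELAB, D65 white point."""
    def to_linear(c):
        # Undo the sRGB gamma curve
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

    def f(t):
        delta = 6.0 / 29.0
        return t ** (1.0 / 3.0) if t > delta ** 3 else t / (3 * delta ** 2) + 4.0 / 29.0

    r, g, b = to_linear(r), to_linear(g), to_linear(b)
    # Linear RGB -> CIE XYZ
    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b
    # XYZ -> Lab, normalized by the D65 reference white
    fx, fy, fz = f(x / 0.95047), f(y / 1.0), f(z / 1.08883)
    return 116 * fy - 16, 500 * (fx - fy), 200 * (fy - fz)

print(srgb_to_lab(1.0, 1.0, 1.0))  # white: L near 100, a and b near 0
print(srgb_to_lab(1.0, 0.0, 0.0))  # red: positive a and positive b
```

Gray pixels have a and b close to zero, so a black-and-white photo carries information only in L. That is why the network takes L as input and only has to predict the two channels a and b.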

2. Initial processing of the dataset with shell commands (runtime: about 20 min)

tar xf data/data9244/ILSVRC2012_img_val.tar -C work/test/
cd ./work/train/;ls ../data/tar/*.tar | xargs -n1 tar xf
# Show the number of images in work/train
find work/train -type f | wc -l

mkdir: cannot create directory ‘work/train’: File exists
mkdir: cannot create directory ‘work/test’: File exists
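For readers who prefer to stay in Python, the same two steps (unpack every archive, then count the files) can be sketched with the standard library. This is an alternative to the shell commands above, not what the project ran; the function names are made up for the example, and the directory arguments are whatever paths your own setup uses.

```python
import tarfile
from pathlib import Path

def extract_all_tars(tar_dir, out_dir):
    """Extract every .tar archive found in tar_dir into out_dir."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    for tar_path in sorted(Path(tar_dir).glob('*.tar')):
        with tarfile.open(tar_path) as archive:
            archive.extractall(out)

def count_files(directory):
    """Count regular files under directory, like `find DIR -type f | wc -l`."""
    return sum(1 for p in Path(directory).rglob('*') if p.is_file())
```

Using `mkdir(parents=True, exist_ok=True)` also avoids the "File exists" messages shown in the output above when the target directories are already there.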

3. Preprocessing

3.1 Preprocessing: delete single-channel images from the training set with multiple threads (runtime: about 20 min)

import os
import imghdr
import threading

import numpy as np
from PIL import Image

'''Delete single-channel images from the dataset using multiple threads'''
def cutArray(l, num):
    # Split list l into num roughly equal consecutive slices
    avg = len(l) / float(num)
    o = []
    last = 0.0

    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg

    return o

def deleteErrorImage(path, image_dir):
    count = 0
    for file in image_dir:
        try:
            image = os.path.join(path, file)
            image_type = imghdr.what(image)
            # Compare values with != and ==, not with identity (is / is not)
            if image_type != 'jpeg':
                os.remove(image)
                count = count + 1
                continue  # the file is gone, so do not try to open it

            img = np.array(Image.open(image))
            if len(img.shape) == 2:  # 2-D array: single-channel (grayscale) image
                os.remove(image)
                count = count + 1
        except Exception as e:
            print(e)
    print('done!')
    print('Deleted count: ' + str(count))

class thread(threading.Thread):
    def __init__(self, threadID, path, files):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.path = path
        self.files = files
    def run(self):
        deleteErrorImage(self.path, self.files)

if __name__ == '__main__':
    path = './work/train/'
    files = os.listdir(path)
    files = cutArray(files, 8)
    # One worker thread per slice of the file list
    threadList = [threading.Thread(target=deleteErrorImage, args=(path, files[i]))
                  for i in range(8)]
    # Start every thread first, then wait for all of them; calling join()
    # right after start() would run the threads one after another
    for t in threadList:
        t.daemon = True
        t.start()
    for t in threadList:
        t.join()

done!
Deleted count: 470
done!
Deleted count: 432
done!
Deleted count: 426
done!
Deleted count: 483
[Errno 2] No such file or directory: ‘./work/train/n02105855_2933.JPEG’
done!
Deleted count: 490
done!
Deleted count: 454
done!
Deleted count: 467
done!
Deleted count: 482
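The split-into-slices-then-one-thread-per-slice pattern used here can also be written with `concurrent.futures` from the standard library. The sketch below is a swapped-in alternative, not the author's code: `split_evenly` and `run_in_threads` are hypothetical helpers, and the slice boundaries are computed with integers so that the number of slices is always exactly the number requested.

```python
from concurrent.futures import ThreadPoolExecutor

def split_evenly(items, num_chunks):
    """Split items into num_chunks consecutive slices using integer bounds."""
    n = len(items)
    bounds = [i * n // num_chunks for i in range(num_chunks + 1)]
    return [items[bounds[i]:bounds[i + 1]] for i in range(num_chunks)]

def run_in_threads(worker, items, num_threads=8):
    """Run worker(chunk) on each slice in its own thread; return the results in order."""
    chunks = split_evenly(items, num_threads)
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(worker, chunks))

# Toy usage: the "work" is just measuring each slice
counts = run_in_threads(len, list(range(10)), num_threads=3)
print(counts)  # [3, 3, 4]
```

In the real task the worker would be a function that takes a list of file names, for example a wrapper around `deleteErrorImage` with the directory path bound in advance. The `with` block waits for all workers, so no explicit `join()` calls are needed.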

3.2 Preprocessing: resize the images and crop them to 256*256 with multiple threads (runtime: about 40 min)

import os
import threading

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

'''Resize images and then crop them to 256*256 using multiple threads'''
w = 256
h = 256

def cutArray(l, num):
    # Split list l into num roughly equal consecutive slices
    avg = len(l) / float(num)
    o = []
    last = 0.0

    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg

    return o

def convertjpg(jpgfile, outdir, width=w, height=h):
    try:
        # Open inside the try block so one unreadable file does not stop the thread
        img = Image.open(jpgfile)
        (l, h) = img.size
        # Scale so that the shorter side becomes `width` pixels
        rate = min(l, h) / width
        img = img.resize((int(l // rate), int(h // rate)), Image.BILINEAR)
        (l, h) = img.size
        # Center-crop to width x height
        lstart = (l - width) // 2
        hstart = (h - height) // 2
        img = img.crop((lstart, hstart, lstart + width, hstart + height))
        img.save(os.path.join(outdir, os.path.basename(jpgfile)))
    except Exception as e:
        print(e)

class thread(threading.Thread):
    def __init__(self, threadID, inpath, outpath, files):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.inpath = inpath
        self.outpath = outpath
        self.files = files
    def run(self):
        count = 0
        try:
            for file in self.files:
                convertjpg(self.inpath + file, self.outpath)
                count = count + 1
        except Exception as e:
            print(e)
        print('Images processed: ' + str(count))

if __name__ == '__main__':
    inpath = './work/train/'
    outpath = './work/train/'
    files = os.listdir(inpath)
    files = cutArray(files, 8)
    # One worker thread per slice of the file list
    threads = [thread(i + 1, inpath, outpath, files[i]) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

 

Images processed: 58782
Images processed: 58783
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782
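The resize-then-center-crop arithmetic in `convertjpg` is easy to check without opening any image. The following sketch repeats only the size calculation; `resize_and_crop_plan` is a hypothetical helper written for this illustration, and the 500x375 input size is just an example.

```python
def resize_and_crop_plan(orig_w, orig_h, target=256):
    """Return (resized_size, crop_box) for a square center crop of side `target`."""
    rate = min(orig_w, orig_h) / target  # the shorter side is scaled to `target`
    new_w, new_h = int(orig_w // rate), int(orig_h // rate)
    left = (new_w - target) // 2
    top = (new_h - target) // 2
    return (new_w, new_h), (left, top, left + target, top + target)

print(resize_and_crop_plan(500, 375))  # ((341, 256), (42, 0, 298, 256))
print(resize_and_crop_plan(375, 500))  # ((256, 341), (0, 42, 256, 298))
```

The shorter side always ends up at 256 pixels, and the crop removes equal margins from the longer side, so the aspect ratio of the content is preserved.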

4. Import the libraries needed by this project

import os
import sys

import cv2
import numpy as np
import six
import matplotlib.pyplot as plt
import sklearn.neighbors as neighbors
from skimage import io, color, transform

import paddle
import paddle.dataset as dataset
import paddle.fluid as fluid

5. Define the data preprocessing tool: DataReader

'''Prepare the data and define Reader()'''

PATH = 'work/train/'
TEST = 'work/train/'
Q = np.load('work/Q.npy')
Weight = np.load('work/Weight.npy')

class DataGenerater:
    def __init__(self):
        self.datalist = os.listdir(PATH)
        self.testlist = os.listdir(TEST)

    def load(self, image):
        '''Read an image, convert it to Lab, and split it into L and ab'''
        img = io.imread(image)
        # transpose() reverses the axes: (H, W, 3) -> (3, W, H)
        lab = np.array(color.rgb2lab(img)).transpose()
        l = lab[:1, :, :]
        l = l.astype('float32')
        ab = lab[1:, :, :]
        ab = ab.astype('float32')
        return l, ab

    def create_train_reader(self):
        '''Define the reader for the training set'''

        def reader():
            for img in self.datalist:
                #print(img)
                try:
                    l, ab = self.load(PATH + img)
                    #print(ab)
                    yield l.astype('float32'), ab.astype('float32')
                except Exception as e:
                    print(e)

        return reader

    def create_test_reader(self):
        '''Define the reader for the test set'''

        def reader():
            for img in self.testlist:
                l, ab = self.load(TEST + img)
                yield l.astype('float32'), ab.astype('float32')

        return reader

def train(batch_sizes=32):
    reader = DataGenerater().create_train_reader()
    return reader

def test():
    reader = DataGenerater().create_test_reader()
    return reader
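The functions above follow the "reader creator" pattern: `train()` and `test()` do not return data, they return a function that yields samples when called. The sketch below shows the pattern with plain Python values in place of image arrays. The `batch` function is a hypothetical stand-in that mimics what `paddle.batch` does with its default settings; it is not Paddle's implementation.

```python
def make_reader(samples):
    """Reader creator: returns a function that yields one sample at a time."""
    def reader():
        for sample in samples:
            yield sample
    return reader

def batch(reader, batch_size):
    """Hypothetical stand-in for paddle.batch: group samples into lists."""
    def batch_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) == batch_size:
                yield buf
                buf = []
        if buf:  # keep the last, smaller batch
            yield buf
    return batch_reader

# Each sample plays the role of an (L, ab) pair
samples = [('L%d' % i, 'ab%d' % i) for i in range(7)]
batched = batch(make_reader(samples), batch_size=3)
for data in batched():
    print(len(data), data[0])
```

Because the reader is a function, it can be called once per epoch to restart the iteration, which is how `train_reader()` is used in the training loop.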

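6. Define the network building blocks and the network

The network consists of basic residual blocks followed by transposed-convolution (deconvolution) layers. As listed here, the encoder has five stride-2 residual stages, and the decoder has five stride-2 transposed-convolution layers plus a final 2-channel output layer.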

import IPython.display as display
import warnings
warnings.filterwarnings('ignore')

Q = np.load('work/Q.npy')
weight = np.load('work/Weight.npy')
Params_dirname = "work/model/gray2color.inference.model"

'''Custom loss function'''
def createLoss(predict, truth):
    '''Mean squared error'''
    loss1 = fluid.layers.square_error_cost(predict, truth)
    # The variance-based second term is commented out in this listing
    #loss2 = fluid.layers.square_error_cost(predict,fluid.layers.fill_constant(shape=[BATCH_SIZE,2,256,256],value=fluid.layers.mean(predict),dtype='float32'))
    cost = fluid.layers.mean(loss1) #+ 16.7 / fluid.layers.mean(loss2)
    return cost

def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  act='relu',
                  bias_attr=True):
    # Convolution followed by batch normalization (momentum 0.5)
    tmp = fluid.layers.conv2d(
        input=input,
        filter_size=filter_size,
        num_filters=ch_out,
        stride=stride,
        padding=padding,
        act=None,
        bias_attr=bias_attr)
    return fluid.layers.batch_norm(input=tmp, act=act, momentum=0.5)


def shortcut(input, ch_in, ch_out, stride):
    # Use a 1x1 convolution on the shortcut path when the channel count changes
    if ch_in != ch_out:
        return conv_bn_layer(input, ch_out, 1, stride, 0, None)
    else:
        return input


def basicblock(input, ch_in, ch_out, stride):
    # Basic residual block: two 3x3 convolutions plus the shortcut
    tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
    tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
    short = shortcut(input, ch_in, ch_out, stride)
    return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')


def layer_warp(block_func, input, ch_in, ch_out, count, stride):
    # Stack `count` blocks; only the first one changes stride and channels
    tmp = block_func(input, ch_in, ch_out, stride)
    for i in range(1, count):
        tmp = block_func(tmp, ch_out, ch_out, 1)
    return tmp

### Transposed convolution (deconvolution) layer
def deconv(x, num_filters, filter_size=5, stride=2, dilation=1, padding=2, output_size=None, act=None):
    return fluid.layers.conv2d_transpose(
        input=x,
        num_filters=num_filters,
        # number of filters
        output_size=output_size,
        # output image size
        filter_size=filter_size,
        # filter size
        stride=stride,
        # stride
        dilation=dilation,
        # dilation rate
        padding=padding,
        use_cudnn=True,
        # whether to use the cuDNN kernel
        act=act
        # activation function
    )
def bn(x, name=None, act=None, momentum=0.5):
    return fluid.layers.batch_norm(
        x,
        bias_attr=None,
        # attribute object for the bias
        moving_mean_name=name + '3',
        # name of moving_mean
        moving_variance_name=name + '4',
        # name of moving_variance
        name=name,
        act=act,
        momentum=momentum,
    )


def resnetImagenet(input):
    # The number in each comment is the spatial size of the feature map after that layer
    # 128
    x = layer_warp(basicblock, input, 64, 128, 1, 2)
    # 64
    x = layer_warp(basicblock, x, 128, 256, 1, 2)
    # 32
    x = layer_warp(basicblock, x, 256, 512, 1, 2)
    # 16
    x = layer_warp(basicblock, x, 512, 1024, 1, 2)
    # 8
    x = layer_warp(basicblock, x, 1024, 2048, 1, 2)
    # 16
    x = deconv(x, num_filters=1024, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_1', act='relu', momentum=0.5)
    # 32
    x = deconv(x, num_filters=512, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_2', act='relu', momentum=0.5)
    # 64
    x = deconv(x, num_filters=256, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_3', act='relu', momentum=0.5)
    # 128
    x = deconv(x, num_filters=128, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_4', act='relu', momentum=0.5)
    # 256
    x = deconv(x, num_filters=64, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_5', act='relu', momentum=0.5)

    # Final layer: 2 output channels (a and b) at full resolution
    x = deconv(x, num_filters=2, filter_size=3, stride=1, padding=1)
    return x
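The size comments in `resnetImagenet` can be reproduced with the usual output-size formulas for convolution and transposed convolution (dilation 1, no output padding). The sketch below is a quick sanity check in plain Python; the two helper functions are assumptions written for this illustration, not Paddle calls.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride, padding):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel

size = 256
sizes = [size]
for _ in range(5):  # five stride-2 residual stages (3x3 kernel, padding 1)
    size = conv_out(size, 3, 2, 1)
    sizes.append(size)
for _ in range(5):  # five stride-2 transposed convolutions (4x4 kernel, padding 1)
    size = deconv_out(size, 4, 2, 1)
    sizes.append(size)
size = deconv_out(size, 3, 1, 1)  # final 2-channel output layer
sizes.append(size)
print(sizes)  # [256, 128, 64, 32, 16, 8, 16, 32, 64, 128, 256, 256]
```

The 4x4 kernel with stride 2 and padding 1 exactly doubles the size at each decoder step, so the 2-channel ab output lines up pixel for pixel with the 256x256 L input.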

7. Train the network

The hyperparameters are:

  • Learning rate: 2e-5
  • Epoch: 30
  • Mini-Batch: 10
  • Input tensor: [-1,1,256,256]

The pretrained inference model is stored at work/model/gray2color.inference.model

BATCH_SIZE = 30
EPOCH_NUM = 300

def ResNettrain():
    gray = fluid.layers.data(name='gray', shape=[1, 256, 256], dtype='float32')
    truth = fluid.layers.data(name='truth', shape=[2, 256, 256], dtype='float32')
    predict = resnetImagenet(gray)
    cost = createLoss(predict=predict, truth=truth)
    return predict, cost


'''Optimizer function'''
def optimizer_program():
    return fluid.optimizer.Adam(learning_rate=2e-5, beta1=0.8)


train_reader = paddle.batch(paddle.reader.shuffle(
    reader=train(), buf_size=7500*3
), batch_size=BATCH_SIZE)
test_reader = paddle.batch(reader=test(), batch_size=10)

use_cuda = True
if not use_cuda:
    os.environ['CPU_NUM'] = str(6)
feed_order = ['gray', 'weight']
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

main_program = fluid.default_main_program()
star_program = fluid.default_startup_program()

'''Build the network'''
predict, cost = ResNettrain()

'''Optimizer'''
optimizer = optimizer_program()
optimizer.minimize(cost)

exe = fluid.Executor(place)

def train_loop():
    gray = fluid.layers.data(name='gray', shape=[1, 256, 256], dtype='float32')
    truth = fluid.layers.data(name='truth', shape=[2, 256, 256], dtype='float32')
    feeder = fluid.DataFeeder(
        feed_list=['gray', 'truth'], place=place)
    exe.run(star_program)

    # Incremental training: resume from a previous checkpoint if one exists
    if os.path.exists('work/model/incremental/'):
        fluid.io.load_persistables(exe, 'work/model/incremental/', main_program)

    for pass_id in range(EPOCH_NUM):
        step = 0
        for data in train_reader():
            loss = exe.run(main_program, feed=feeder.feed(data), fetch_list=[cost])
            step += 1
            if step % 1000 == 0:
                # Every 1000 steps, plot the first 10 predictions of the current batch
                try:
                    generated_img = exe.run(main_program, feed=feeder.feed(data), fetch_list=[predict])
                    plt.figure(figsize=(15, 6))
                    plt.grid(False)
                    for i in range(10):
                        ab = generated_img[0][i]
                        l = data[i][0][0]
                        a = ab[0]
                        b = ab[1]
                        l = l[:, :, np.newaxis]
                        a = a[:, :, np.newaxis].astype('float64')
                        b = b[:, :, np.newaxis].astype('float64')
                        lab = np.concatenate((l, a, b), axis=2)
                        img = color.lab2rgb((lab))
                        # Undo the axis swap introduced by transpose() in the reader
                        img = transform.rotate(img, 270)
                        img = np.fliplr(img)
                        plt.grid(False)
                        plt.subplot(2, 5, i + 1)
                        plt.imshow(img)
                        plt.axis('off')
                        plt.xticks([])
                        plt.yticks([])
                    msg = 'Epoch ID={0} Batch ID={1} Loss={2}'.format(pass_id, step, loss[0][0])
                    plt.suptitle(msg, fontsize=20)
                    plt.draw()
                    plt.savefig('{}/{:04d}_{:04d}.png'.format('work/output_img', pass_id, step), bbox_inches='tight')
                    plt.pause(0.01)
                    display.clear_output(wait=True)
                except IOError as e:
                    print(e)

        # Save a checkpoint and the inference model after each epoch
        fluid.io.save_persistables(exe, 'work/model/incremental/', main_program)
        fluid.io.save_inference_model(Params_dirname, ["gray"], [predict], exe)
train_loop()
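The training code sets Adam's `beta1` to 0.8 and the learning rate to 2e-5. As a rough guide to what `beta1` controls, here is a scalar sketch of one Adam update following the standard formulation from the Adam paper. The `beta2` and `eps` values are common defaults assumed for the example, and Paddle's own implementation may differ in details such as where epsilon is applied.

```python
import math

def adam_step(theta, grad, m, v, t, lr=2e-5, beta1=0.8, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running average of the gradient
    v = beta2 * v + (1 - beta2) * grad * grad   # running average of the squared gradient
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = adam_step(theta=1.0, grad=2.0, m=0.0, v=0.0, t=1)
print(theta, m, v)
```

With `beta1 = 0.8`, a gradient from k steps ago contributes with weight 0.2 * 0.8**k to the running average, so the optimizer forgets old gradients faster than with the more common value 0.9.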

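8. Summary

This article walked through the project step by step.
As for the training results, although the project raises the saturation of the colorized output by suppressing averaging and increasing dispersion, the final results are still far from reality. The model works well only for some scenes and still performs poorly on man-made scenes (such as supermarket interiors).
The next step is to refine the loss function further, with the goal of producing colorizations that can fool human intuition rather than simply approaching the ground-truth scene.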

This article is also published on the blog "Redflashing" (CSDN).