This post is mainly based on the paper Generative Adversarial Networks (link).
To keep things easy to explain and study, I designed only a very simple model here, which generates samples from a Gaussian distribution. Even so, the experiments below revealed some very useful properties that can deepen our understanding of GANs.
For the full theory, see the reference above; here is a brief overview.

The idea behind a GAN is actually quite simple. It consists of two sub-networks. One is the Generator: it takes noise samples as input and, through its learned weights, transforms (i.e. generates) the noise into a meaningful signal. The other is the Discriminator: it takes a signal as input (either one produced by the Generator or a real one), learns to judge whether that signal is real or fake, and outputs a probability between 0 and 1. A common analogy: the Generator is a counterfeit money printer and the Discriminator is a bill validator; the two compete, so the counterfeits become ever more convincing while the validator becomes ever more accurate. Ultimately, though, we want the Generator to win, leaving the Discriminator outputting 0.5 everywhere, i.e. unable to tell real from fake.
Training proceeds in two phases. In the first phase the Discriminator is trained: the Generator's weights are frozen and only the Discriminator's weights are updated. Its loss function is:

$$ \frac{1}{m}\sum_{i=1}^{m}[\log D(x^i) + \log(1 - D(G(z^i)))] $$

where $m$ is the batch_size, $x$ denotes a real signal, and $z$ a noise sample. At each step, $m$ noise samples and $m$ real samples are drawn from the noise and real distributions respectively, and the Discriminator's weights are updated by *maximizing* the loss function above.
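As a sanity check on the sign convention, the discriminator objective above can be evaluated directly with NumPy. The probability values here are invented purely for illustration:

```python
import numpy as np

def discriminator_objective(d_real, d_fake):
    """(1/m) * sum(log D(x_i) + log(1 - D(G(z_i)))) -- to be MAXIMIZED."""
    eps = 1e-7  # guard against log(0), mirroring K.epsilon() in the code below
    return np.mean(np.log(np.maximum(eps, d_real)) +
                   np.log(np.maximum(eps, 1.0 - d_fake)))

# A well-trained discriminator scores real samples high and fakes low:
good = discriminator_objective(np.array([0.9, 0.8]), np.array([0.1, 0.2]))
# A confused discriminator outputs ~0.5 everywhere:
confused = discriminator_objective(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
assert good > confused  # maximizing the objective sharpens discrimination
```

The confused case evaluates to $2\log 0.5 \approx -1.386$, so maximization pushes the Discriminator away from the "can't tell" state.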
The second phase trains the Generator, with the loss function:

$$ \frac{1}{m}\sum_{i=1}^{m}[\log(1 - D(G(z^i)))] $$

This time, however, the Generator's weights are updated by *minimizing* the loss.
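The generator term can be checked the same way (again with toy probabilities, not outputs of the actual model):

```python
import numpy as np

def generator_loss(d_fake):
    """(1/m) * sum(log(1 - D(G(z_i)))) -- to be MINIMIZED by the generator."""
    eps = 1e-7
    return np.mean(np.log(np.maximum(eps, 1.0 - d_fake)))

# When the generator fools the discriminator (D(G(z)) -> 1), the loss drops:
fooled = generator_loss(np.array([0.9, 0.95]))
caught = generator_loss(np.array([0.05, 0.1]))
assert fooled < caught  # minimizing the loss drives D(G(z)) toward 1
```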
Note also that the two phases do not simply alternate one-for-one: the Discriminator is updated K times for every single Generator update.

As the experiments below show, the choice of K is crucial.
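The K-to-1 schedule can be sketched as a plain loop. `train_discriminator` and `train_generator` are hypothetical placeholders for the actual batch updates, passed in as callables:

```python
def gan_training_schedule(epochs, K, train_discriminator, train_generator):
    """Run K discriminator updates for every single generator update."""
    for epoch in range(epochs):
        for _ in range(K):
            train_discriminator()  # maximize log D(x) + log(1 - D(G(z)))
        train_generator()          # minimize log(1 - D(G(z)))

# Example: 2 epochs with K = 3 yields the pattern d,d,d,g,d,d,d,g
calls = []
gan_training_schedule(2, 3, lambda: calls.append("d"), lambda: calls.append("g"))
assert calls == ["d", "d", "d", "g", "d", "d", "d", "g"]
```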
The main tools are python + keras. Implementing common networks in keras is especially easy: MLP, word2vec, LeNet, LSTM and so on all have detailed demos on github. Anything slightly more complex takes some time to write yourself, but overall it is still far more convenient than raw tensorflow. Besides, keras itself can serve as reference code for learning tf; many of its idioms are well worth borrowing.

Enough talk; here is the code.

Only the most important parts are listed:
```python
# Loss functions designed specifically for the GAN
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))

def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))

class GANModel:
    def __init__(self, input_dim, log_dir=None):
        '''
        __tensor[0]: expression for the discriminator; judges y (true samples)
        __tensor[1]: expression for the generator; generates from x (noise samples)
        '''
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim)

        # Names are required so the two signals can be fed separately later
        self.__inputs = [layers.Input(shape=(input_dim_y,), name="y"),
                         layers.Input(shape=(input_dim_x,), name="x")]
        self.__tensors = [None, None]
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)

    def add_gen_layer(self, layer):
        self._add_layer(layer, True)

    def add_discr_layer(self, layer):
        self._add_layer(layer)

    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)

        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])

    def compile_discriminateor_model(self, optimizer=optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling the discriminator model")

        # Setting trainable = False freezes weight updates; it must happen before compile
        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False

        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
        # With two output signals, keras applies the loss function to each signal
        # separately, sums the results, and minimizes the sum. The double-underscore
        # model is the one that participates in training.
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, loss=log_loss_discriminator)

        # This is the actual discriminator model
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, loss=log_loss_discriminator)
        if self.log_dir is not None:
            # Requires pydot and graphviz; comment out if not installed
            plot_model(self.__discriminateor_model,
                       self.log_dir + "/gan_discriminateor_model.png", show_shapes=True)

    def compile_generator_model(self, optimizer=optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling the generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling it")

        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True

        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, loss=log_loss_generator)

        # This is the actual generator model
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        if self.log_dir is not None:
            plot_model(self.__generator_model,
                       self.log_dir + "/gan_generator_model.png", show_shapes=True)

    def train(self, sample_list, epoch=3, batch_size=32, step_per=10, plot=False):
        '''
        step_per: how many discriminator steps per generator step, i.e. K
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size
        # A bit of a trick: a keras model requires a y, but a GAN has none,
        # so we fabricate one to satisfy keras' "unreasonable" demand
        psudo_y = np.ones((batch_size,), dtype='float32')

        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            plt.ion()
            plt.show()

        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count - 1)
                batch_noise = sample_noise[idx * batch_size : (idx + 1) * batch_size]
                idx = random.randint(0, batch_count - 1)
                batch_sample = sample_true[idx * batch_size : (idx + 1) * batch_size]
                self.__discriminateor_model.train_on_batch(
                    {"y": batch_sample, "x": batch_noise}, [psudo_y, psudo_y])

            idx = random.randint(0, batch_count - 1)
            batch_noise = sample_noise[idx * batch_size : (idx + 1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)

            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)

                freq_g, bin_g = np.histogram(gen_result, density=True)
                # normalize so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize=4)
                ax.plot(bin_d, freq_d, 'ko-', markersize=8)

                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei + 1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei + 1))
                plt.pause(0.05)

        if plot:
            plt.ioff()
            plt.close()
```
Only the main part is listed; it shows the model structure and parameter values:
```python
step_per = 20
sample_size = args.batch_size * 100   # the whole test sample set
noise_dim = 4
signal_dim = 1
x = np.random.uniform(-3, 3, size=(sample_size, noise_dim))
y = np.random.normal(size=(sample_size, signal_dim))
samples = {"x": x, "y": y}

gan = GANModel([signal_dim, noise_dim], args.log_dir)
gan.add_discr_layer(layers.Dense(200, activation="relu"))
gan.add_discr_layer(layers.Dense(50, activation="softmax"))
gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                  output_shape=(1,)))

gan.add_gen_layer(layers.Dense(200, activation="relu"))
gan.add_gen_layer(layers.Dense(100, activation="relu"))
gan.add_gen_layer(layers.Dense(50, activation="relu"))
gan.add_gen_layer(layers.Dense(signal_dim))

gan.compile_generator_model()
loger.info("compile generator finished")
gan.compile_discriminateor_model()
loger.info("compile discriminator finished")

gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)
```
In the paper, the authors note that K has a large impact on the training result.

With step_per = 20 as above, the results I got were quite good:

As you can see, the data generated by the Generator (green line) ends up very close to the real Gaussian distribution (black line), so close that the Discriminator can no longer tell them apart (p = 0.5).

With step_per set to 3, however, training diverges badly and struggles to converge:

The paper also points out that the Discriminator and the Generator must be well matched; in general the Discriminator should be trained several times for each Generator update. This is because a good Discriminator is a prerequisite for the Generator: if D is not trained well, G's update direction will be inaccurate.
I also found that noise_dim has a very large effect on the result. The runs above used noise_dim = 4; after setting it to 1, the model seemed unable to converge to a true Gaussian, always falling a bit short of the real one.

So my conjecture is: the Generator's input can be viewed as a projection of the real signal onto other dimensions. Through training, the model finds the mapping between the two, so conversely the Generator can be seen as decomposing the real signal into a higher-dimensional space. Naturally, the higher the dimension, the better the signal is decomposed and the easier it is to approximate the real signal.

Moreover, from a curve-fitting point of view: the Gaussian density in my experiment is nonlinear, while the activations used (ReLU) are only piecewise linear. If the noise were also 1-dimensional, we would essentially be fitting a nonlinear function with a pile of piecewise-linear ones, which can only work well in a higher-dimensional space.

Training a stable GAN is a very tricky process; fortunately, experts have already explored this extensively. See here for details.
# demo_gan.py

```python
# -*- encoding: utf8 -*-
'''
GAN demo
'''
import os
from os import path
import argparse
import logging
import traceback
import random
import pickle
from collections import defaultdict

import numpy as np
import tensorflow as tf
from keras import optimizers
from keras import layers
from keras import callbacks, regularizers, activations
from keras.engine import Model
from keras.utils.vis_utils import plot_model
import keras.backend as K
from matplotlib import pyplot as plt

import app_logger

loger = logging.getLogger(__name__)

# Note: pred must never be negative, since pred is a probability;
# choose the final activation function accordingly
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))

def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))

class GANModel:
    def __init__(self, input_dim, log_dir=None):
        '''
        __tensor[0]: expression for the discriminator
        __tensor[1]: expression for the generator
        '''
        # the discriminator judges y (true samples);
        # the generator generates from x (noise samples)
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim)

        self.__inputs = [layers.Input(shape=(input_dim_y,), name="y"),
                         layers.Input(shape=(input_dim_x,), name="x")]
        self.__tensors = [None, None]
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)

    def add_gen_layer(self, layer):
        self._add_layer(layer, True)

    def add_discr_layer(self, layer):
        self._add_layer(layer)

    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)

        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])

    def compile_discriminateor_model(self, optimizer=optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling the discriminator model")

        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False

        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, loss=log_loss_discriminator)

        # This is the discriminator model we actually need
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, loss=log_loss_discriminator)
        #if self.log_dir is not None:
        #    plot_model(self.__discriminateor_model,
        #               self.log_dir + "/gan_discriminateor_model.png", show_shapes=True)

    def compile_generator_model(self, optimizer=optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling the generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling it")

        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True

        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, loss=log_loss_generator)

        # This is the generator model we actually need
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        #if self.log_dir is not None:
        #    plot_model(self.__generator_model,
        #               self.log_dir + "/gan_generator_model.png", show_shapes=True)

    def train(self, sample_list, epoch=3, batch_size=32, step_per=10, plot=False):
        '''
        step_per: how many discriminator steps per generator step
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size
        psudo_y = np.ones((batch_size,), dtype='float32')

        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            plt.ion()
            plt.show()

        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count - 1)
                batch_noise = sample_noise[idx * batch_size : (idx + 1) * batch_size]
                idx = random.randint(0, batch_count - 1)
                batch_sample = sample_true[idx * batch_size : (idx + 1) * batch_size]
                self.__discriminateor_model.train_on_batch(
                    {"y": batch_sample, "x": batch_noise}, [psudo_y, psudo_y])

            idx = random.randint(0, batch_count - 1)
            batch_noise = sample_noise[idx * batch_size : (idx + 1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)

            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)

                freq_g, bin_g = np.histogram(gen_result, density=True)
                # normalize so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize=4)
                ax.plot(bin_d, freq_d, 'ko-', markersize=8)

                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei + 1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei + 1))
                plt.pause(0.05)

        if plot:
            plt.ioff()
            plt.close()

    def save_model(self, path_dir):
        self.generator_model.save(path_dir + "/gan_generator.h5")
        self.discriminateor_model.save(path_dir + "/gan_discriminateor.h5")

    def load_model(self, path_dir):
        from keras.models import load_model
        custom_obj = {
            "log_loss_discriminator": log_loss_discriminator,
            "log_loss_generator": log_loss_generator}
        self.generator_model = load_model(path_dir + "/gan_generator.h5", custom_obj)
        self.discriminateor_model = load_model(path_dir + "/gan_discriminateor.h5", custom_obj)

    def _discriminate_generated(self):
        # must be rebuilt on every call
        disc_t = self.__tensors[1]
        for l in self._discriminate_layers:
            disc_t = l(disc_t)
        return disc_t

if __name__ == "__main__":
    parser = argparse.ArgumentParser("""gan model demo (gaussian sample)""")
    parser.add_argument("-m", "--model_dir")
    parser.add_argument("-log", "--log_dir")
    parser.add_argument("-b", "--batch_size", type=int, default=32)
    parser.add_argument("-log_lvl", "--log_lvl", default="info",
                        metavar="one of INFO, DEBUG, WARN, ERROR")
    parser.add_argument("-e", "--epoch", type=int, default=10)
    args = parser.parse_args()

    log_lvl = {"info": logging.INFO, "debug": logging.DEBUG,
               "warn": logging.WARN, "warning": logging.WARN,
               "error": logging.ERROR, "err": logging.ERROR}[args.log_lvl.lower()]
    app_logger.init(log_lvl)
    loger.info("args: %r" % args)

    step_per = 20
    sample_size = args.batch_size * 100   # the whole test sample set
    noise_dim = 4
    signal_dim = 1
    x = np.random.uniform(-3, 3, size=(sample_size, noise_dim))
    y = np.random.normal(size=(sample_size, signal_dim))
    samples = {"x": x, "y": y}

    gan = GANModel([signal_dim, noise_dim], args.log_dir)
    gan.add_discr_layer(layers.Dense(200, activation="relu"))
    gan.add_discr_layer(layers.Dense(50, activation="softmax"))
    gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                      output_shape=(1,)))

    gan.add_gen_layer(layers.Dense(200, activation="relu"))
    gan.add_gen_layer(layers.Dense(100, activation="relu"))
    gan.add_gen_layer(layers.Dense(50, activation="relu"))
    gan.add_gen_layer(layers.Dense(signal_dim))

    gan.compile_generator_model()
    loger.info("compile generator finished")
    gan.compile_discriminateor_model()
    loger.info("compile discriminator finished")

    gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)

    gen_results = gan.train_status["gen_result"]
    dis_results = gan.train_status["dis_result"]
    gen_result = gen_results[-1]
    dis_result = dis_results[-1]
    freq_g, bin_g = np.histogram(gen_result, density=True)
    # normalize so the frequencies sum to 1
    freq_g = freq_g * (bin_g[1] - bin_g[0])
    bin_g = bin_g[:-1]
    freq_d, bin_d = np.histogram(y, bins=100, density=True)
    freq_d = freq_d * (bin_d[1] - bin_d[0])
    bin_d = bin_d[:-1]
    plt.plot(bin_g, freq_g, 'go-', markersize=4)
    plt.plot(bin_d, freq_d, 'ko-', markersize=8)
    gen1d = gen_result.flatten()
    dis1d = dis_result.flatten()
    si = np.argsort(gen1d)
    plt.plot(gen1d[si], dis1d[si], 'r--')
    plt.savefig("img/gan_results.png")

    if not path.exists(args.model_dir):
        os.mkdir(args.model_dir)
    gan.save_model(args.model_dir)
```

# app_logger.py

```python
import logging

def init(lvl=logging.DEBUG):
    log_handler = logging.StreamHandler()
    # create formatter
    formatter = logging.Formatter(
        '[%(asctime)s] %(levelname)s %(filename)s:%(funcName)s:%(lineno)d > %(message)s')
    log_handler.setFormatter(formatter)
    logging.basicConfig(level=lvl, handlers=[log_handler])
```