对抗生成网络(GAN)变体集合 Keras版本
一.ACGAN(Auxiliary Classifier GAN)
https://arxiv.org/abs/1610.09585
依旧有Generator和Discriminator,可以使用MNIST训练生成图片。
和DCGAN的不一样:
1.增加了class类别标签参与训练,能够生成指定类别的图片
代码引用自《Web安全之强化学习与GAN》,位置:
https://github.com/duoergun0729/3book/tree/master/code/keras-acgan.py
生成器G代码:
def build_generator(latent_size):
    """Build the ACGAN generator: (noise vector, class label) -> fake image.

    The class label is embedded into a vector of the same width as the noise
    vector, the two are summed element-wise, and the result is decoded into
    an image by a dense + upsampling/convolution stack.

    Args:
        latent_size: Width of the latent noise vector. The class embedding is
            sized to match it so the element-wise sum is valid for any value
            (the embedding was previously fixed at 100, which only worked
            when ``latent_size == 100``).

    Returns:
        A Keras ``Model`` with inputs ``[latent, image_class]`` and a single
        output, the generated image (``tanh`` activation, so values lie in
        ``[-1, 1]``).
    """
    cnn = Sequential()

    # Dense stem: project the conditioned latent vector up to 128 * 7 * 7
    # activations, which are then reshaped into a small feature map.
    cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))
    cnn.add(Dense(128 * 7 * 7, activation='relu'))
    # NOTE(review): the (128, 7, 7) layout presumes the Keras backend is
    # configured for channels-first images -- confirm against the file's
    # image_data_format setting.
    cnn.add(Reshape((128, 7, 7)))

    # Two 2x upsampling stages, each followed by a 5x5 convolution.
    cnn.add(UpSampling2D(size=(2, 2)))
    cnn.add(Conv2D(256, (5, 5), padding="same",
                   kernel_initializer="glorot_normal", activation="relu"))
    cnn.add(UpSampling2D(size=(2, 2)))
    cnn.add(Conv2D(128, (5, 5), padding="same",
                   kernel_initializer="glorot_normal", activation="relu"))

    # Collapse to a single output channel.
    cnn.add(Conv2D(1, (2, 2), padding="same",
                   kernel_initializer="glorot_normal", activation="tanh"))

    # Noise input.
    latent = Input(shape=(latent_size, ))

    # Integer class label input; the embedding table below has 10 rows.
    image_class = Input(shape=(1,), dtype='int32')

    # Look up a learned vector per class and flatten (1, latent_size) down to
    # (latent_size,). The width must equal latent_size so it can be combined
    # with the noise vector and match the Dense stem's input_dim.
    cls = Flatten()(Embedding(10, latent_size,
                              embeddings_initializer="glorot_normal")(image_class))

    # Condition the noise on the class by element-wise addition. (An earlier
    # revision of this code used an element-wise product here instead.)
    h = add([latent, cls])

    fake_image = cnn(h)

    return Model(inputs=[latent, image_class], outputs=[fake_image])
判别器D代码:
def build_discriminator():
    """Build the ACGAN discriminator: image -> (generation score, class probs).

    A four-stage convolutional feature extractor feeds two dense heads that
    share the same flattened features.

    Returns:
        A Keras ``Model`` with one image input of shape ``(1, 28, 28)`` and
        two outputs, in order:
        ``generation`` -- one sigmoid unit;
        ``auxiliary`` -- a 10-way softmax.
    """
    # NOTE(review): (1, 28, 28) presumes channels-first image layout --
    # confirm against the backend's image_data_format setting.
    image_shape = (1, 28, 28)

    # (filters, strides) for each 3x3 convolution stage, in order.
    stage_specs = (
        (32, (2, 2)),
        (64, (1, 1)),
        (128, (2, 2)),
        (256, (1, 1)),
    )

    cnn = Sequential()
    for position, (filters, strides) in enumerate(stage_specs):
        # Only the very first layer of the Sequential declares the input shape.
        first_layer_kwargs = {"input_shape": image_shape} if position == 0 else {}
        cnn.add(Conv2D(filters, (3, 3), padding="same", strides=strides,
                       **first_layer_kwargs))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))
    cnn.add(Flatten())

    image = Input(shape=image_shape)
    features = cnn(image)

    # Both heads read the same feature vector.
    generation_head = Dense(1, activation='sigmoid', name='generation')
    auxiliary_head = Dense(10, activation='softmax', name='auxiliary')
    fake = generation_head(features)
    aux = auxiliary_head(features)

    return Model(inputs=[image], outputs=[fake, aux])
训练图: