Conditional GAN (CGAN) 有条件的生成
学习资料:
- 我制作的GAN简介短片
- 我制作的让GAN生成你想要的
- 论文 Conditional Generative Adversarial Nets
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了 ¶
最开始的GAN是没办法根据标签来生成特定种类的数据的,比如我训练GAN生成猫狗的图片,但是现在我只想获得猫的图片, 经典的GAN才不管那么多,它还是会一直生成又有猫又有够的图片。 我用这个短片细说了一下有条件生成的GAN。
比如用上面的正常生成逻辑,是用随机噪点生成猫狗,我们能用什么方法来控制到底要生成猫呢还是生成狗呢? 这时,CGAN 就来拯救你啦~ 用一句话来总结 CGAN:把标签一起送进生成器和判别器,让他们根据标签来生成/判别结果。
如果从动态训练的角度来看,mnist 训练的整个过程都可以按照标签顺序来生成了,真棒。 想象哪天你有一个女神生成器,你也可以按照某个特定女神来生成了,多好~
怎么训练 ¶
这个很简单,只需要在训练和预测的时候,在 Generator 和 Discriminator 的输入端多给一个 input,这个 input 作用就是提供一个标签。 让 Generator 知道这张照片该生成什么,让 Discriminator 知道这张照片我应该判别是:它是否是此标签类别。
所以改动经典GAN的程度相对比较少,而我们在 mnist 数据加工的时候,还要额外做一道工序,除了拿出手写数字图片,还要将数字标签也拿出来。
秀代码 ¶
如果想直接看全部代码, 请点击这里去往我的github.
首先不同于经典GAN的一点是,除了图片,我还要提供图片对应的标签信息。
def train(gan, ds):
for ep in range(EPOCH):
for t, (real_img, real_img_label) in enumerate(ds):
gan.step(real_img, real_img_label) # 这里要额外传入图片标签啦
其次不同于经典GAN的是,他的生成器和判别器都要额外添加一个标签的输入信息。 我们还可以给标签做一个 embedding,使它拥有更丰富的信息量。
class CGAN(keras.Model):
def _get_discriminator(self):
img = Input(shape=self.img_shape) # 图片
label = Input(shape=(), dtype=tf.int32) # 标签
label_emb = Embedding(10, 32)(label) # 标签向量化
emb_img = Reshape((28, 28, 1))(Dense(28*28, activation=keras.activations.relu)(label_emb))
concat_img = tf.concat((img, emb_img), axis=3) # 标签和图片一起输入网络
s = keras.Sequential([
mnist_uni_disc_cnn(input_shape=[28, 28, 2]),
Dense(1)
])
o = s(concat_img)
model = keras.Model([img, label], o, name="discriminator")
return model
def _get_generator(self):
noise = Input(shape=(self.latent_dim,)) # 噪声
label = Input(shape=(), dtype=tf.int32) # 标签
# 这里我做的是 onehot 标签,你也可以把它 embedding 化
label_onehot = tf.one_hot(label, depth=self.label_dim)
model_in = tf.concat((noise, label_onehot), axis=1)
s = mnist_uni_gen_cnn((self.latent_dim+self.label_dim,))
o = s(model_in)
model = keras.Model([noise, label], o, name="generator")
return model
看到这里,我假设你已经有经典GAN的基础了,所以其它的训练步骤我就省略了,不过你可以在 我的Github看到所有代码。 训练结果也是非常出色的,比较好的按照我规定的右下角标签生成了所有数字。
总结 ¶
CGAN 是有条件生成的一种尝试性研究,我们在后面还可以看到更多在这方面的努力,他们带来了更丰富的条件性 GAN 模型。