Conditional GAN (CGAN) 有条件的生成

作者: 莫烦 编辑: 莫烦 2021-04-01

学习资料:

怎么了

最开始的GAN是没办法根据标签来生成特定种类的数据的,比如我训练GAN生成猫狗的图片,但是现在我只想获得猫的图片, 经典的GAN才不管那么多,它还是会一直生成又有猫又有够的图片。 我用这个短片细说了一下有条件生成的GAN。

progress mnist

比如用上面的正常生成逻辑,是用随机噪点生成猫狗,我们能用什么方法来控制到底要生成猫呢还是生成狗呢? 这时,CGAN 就来拯救你啦~ 用一句话来总结 CGAN:把标签一起送进生成器和判别器,让他们根据标签来生成/判别结果。

training

如果从动态训练的角度来看,mnist 训练的整个过程都可以按照标签顺序来生成了,真棒。 想象哪天你有一个女神生成器,你也可以按照某个特定女神来生成了,多好~

results

怎么训练

这个很简单,只需要在训练和预测的时候,在 Generator 和 Discriminator 的输入端多给一个 input,这个 input 作用就是提供一个标签。 让 Generator 知道这张照片该生成什么,让 Discriminator 知道这张照片我应该判别是:它是否是此标签类别。

所以改动经典GAN的程度相对比较少,而我们在 mnist 数据加工的时候,还要额外做一道工序,除了拿出手写数字图片,还要将数字标签也拿出来。

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

首先不同于经典GAN的一点是,除了图片,我还要提供图片对应的标签信息。

def train(gan, ds):
    for ep in range(EPOCH):
        for t, (real_img, real_img_label) in enumerate(ds):
            gan.step(real_img, real_img_label)  # 这里要额外传入图片标签啦

其次不同于经典GAN的是,他的生成器和判别器都要额外添加一个标签的输入信息。 我们还可以给标签做一个 embedding,使它拥有更丰富的信息量。

class CGAN(keras.Model):
    def _get_discriminator(self):
        img = Input(shape=self.img_shape)       # 图片
        label = Input(shape=(), dtype=tf.int32) # 标签
        label_emb = Embedding(10, 32)(label)    # 标签向量化
        emb_img = Reshape((28, 28, 1))(Dense(28*28, activation=keras.activations.relu)(label_emb))
        concat_img = tf.concat((img, emb_img), axis=3)  # 标签和图片一起输入网络
        s = keras.Sequential([
            mnist_uni_disc_cnn(input_shape=[28, 28, 2]),
            Dense(1)
        ])
        o = s(concat_img)
        model = keras.Model([img, label], o, name="discriminator")
        return model

    def _get_generator(self):
        noise = Input(shape=(self.latent_dim,))     # 噪声
        label = Input(shape=(), dtype=tf.int32)     # 标签
        # 这里我做的是 onehot 标签,你也可以把它 embedding 化
        label_onehot = tf.one_hot(label, depth=self.label_dim)  
        model_in = tf.concat((noise, label_onehot), axis=1)
        s = mnist_uni_gen_cnn((self.latent_dim+self.label_dim,))
        o = s(model_in)
        model = keras.Model([noise, label], o, name="generator")
        return model

看到这里,我假设你已经有经典GAN的基础了,所以其它的训练步骤我就省略了,不过你可以在 我的Github看到所有代码。 训练结果也是非常出色的,比较好的按照我规定的右下角标签生成了所有数字。

res

总结

CGAN 是有条件生成的一种尝试性研究,我们在后面还可以看到更多在这方面的努力,他们带来了更丰富的条件性 GAN 模型。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。