GAN Generative adversarial network 生成对抗网络

参考资料

http://t.csdnimg.cn/qCodZ

【生成对抗网络GAN原理解析】https://www.bilibili.com/video/BV1nA4m1N74j?vd_source=bd967f0d540a64617b8b612bc0f0f9a3

预备知识

GAN的全称是Generative adversarial network,中文翻译过来就是生成对抗网络。生成对抗网络其实是两个网络的组合:生成网络(Generator)负责生成模拟数据;判别网络(Discriminator)负责判断输入的数据是真实的还是生成的。生成网络要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确。二者关系形成对抗,因此叫对抗网络。

pkz8XeU.png

以下内容来自chatgpt3.5。

1. GAN的基本结构

  • 生成器(Generator):负责生成假数据。生成器接收随机噪声(通常是从正态分布或均匀分布中采样的向量)作为输入,并输出假数据(如图像)。
  • 判别器(Discriminator):负责辨别数据的真假。判别器接收输入数据(可以是真实数据或生成的数据),输出一个概率,表示该数据是真实的概率。

2. GAN的训练过程

GAN的训练过程是一个两步循环:

  1. 训练判别器:使用一批真实数据和一批生成器生成的假数据来训练判别器,使其能够区分真实数据和假数据。
  2. 训练生成器:通过生成器生成一批假数据,并使用判别器的反馈(即假数据被判别器认为是真实数据的概率)来更新生成器的参数,使生成器生成的假数据越来越逼真。

具体的损失函数如下:

  • 判别器的损失函数:用于衡量判别器对真实数据和假数据的辨别能力。 LD=Expdata[logD(x)]Ezpz[log(1D(G(z)))]LD=Expdata[logD(x)]Ezpz[log(1D(G(z)))]LD=Expdata[logD(x)]Ezpz[log(1D(G(z)))]LD=−Ex∼pdata[log⁡D(x)]−Ez∼pz[log⁡(1−D(G(z)))]\mathcal{L}_D = -\mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [\log D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}} [\log (1 - D(G(\mathbf{z})))]LD=−Ex∼pdata[logD(x)]−Ez∼pz[log(1−D(G(z)))]

  • 生成器的损失函数:用于衡量生成器生成的假数据被判别器认为是真实数据的能力。 LG=Ezpz[logD(G(z))]LG=Ezpz[logD(G(z))]LG=Ezpz[logD(G(z))]LG=−Ez∼pz[log⁡D(G(z))]\mathcal{L}_G = -\mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}} [\log D(G(\mathbf{z}))]LG=−Ez∼pz[logD(G(z))]

3. GAN的变种

由于GANs在实践中存在训练不稳定和模式崩溃(mode collapse)等问题,许多变种和改进方法被提出:

  • DCGAN(Deep Convolutional GANs):使用卷积神经网络来构建生成器和判别器,适用于图像数据。
  • WGAN(Wasserstein GAN):通过使用Earth-Mover距离(也称为Wasserstein距离)来改进损失函数,增强训练稳定性。
  • CycleGAN:用于图像到图像的转换,如将白天的图像转换为夜晚的图像。
  • StyleGAN:一种高级的图像生成模型,能够生成高质量的人脸图像,具有很好的控制生成图像样式的能力。

4. GAN的应用

GANs有广泛的应用,特别是在图像生成和处理方面:

  • 图像生成:生成逼真的图像,如人脸、风景、物体等。
  • 图像修复:修复损坏或缺失的图像部分。
  • 图像超分辨率:提高低分辨率图像的分辨率。
  • 图像到图像翻译:将一种图像转换为另一种图像,如黑白图像上色、素描转换为照片。
  • 数据增强:为训练数据集生成更多样本,提高模型的泛化能力。

5. 示例代码

以下是一个简单的GAN实现的Python代码示例,使用Keras库:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# 超参数
latent_dim = 100
img_shape = (28, 28, 1)

# 生成器模型
def build_generator():
model = Sequential()
model.add(Dense(256, input_dim=latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(np.prod(img_shape), activation='tanh'))
model.add(Reshape(img_shape))
return model

# 判别器模型
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
return model

# 编译模型
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
generator = build_generator()

# 连接生成器和判别器
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)
discriminator.trainable = False
valid = discriminator(img)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

# 训练GAN
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = X_train / 127.5 - 1.0
X_train = np.expand_dims(X_train, axis=3)

batch_size = 64
epochs = 10000

for epoch in range(epochs):
# 训练判别器
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_imgs = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, latent_dim))
gen_imgs = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# 训练生成器
noise = np.random.normal(0, 1, (batch_size, latent_dim))
valid_y = np.ones((batch_size, 1))
g_loss = combined.train_on_batch(noise, valid_y)

# 输出训练过程
if epoch % 1000 == 0:
print(f"{epoch} [D loss: {d_loss[0]}] [G loss: {g_loss}]")