Jos*_*mon 6 python artificial-intelligence machine-learning keras generative-adversarial-network
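I want to use a conditional GAN to generate images for one domain (noted as domain A) by feeding in images from a second domain (noted as domain B) together with class information. Both domains are linked to the same label information (every image of domain A is linked to an image of domain B and to a specific label). My generator in Keras so far is the following: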
from tensorflow.keras.layers import (Input, BatchNormalization, Flatten, Dense,
                                     Reshape, Activation)
from tensorflow.keras.models import Model

# IN_CH, img_cols and img_rows are globals defined elsewhere
def generator_model_v2():
    # Input: a batch of domain A images, channels first
    inputs = Input((IN_CH, img_cols, img_rows))
    e1 = BatchNormalization()(inputs)
    e2 = Flatten()(e1)
    e3 = BatchNormalization()(e2)
    e4 = Dense(1024, activation="relu")(e3)
    e5 = BatchNormalization()(e4)
    e6 = Dense(512, activation="relu")(e5)
    e7 = BatchNormalization()(e6)
    e8 = Dense(512, activation="relu")(e7)
    e9 = BatchNormalization()(e8)
    e10 = Dense(IN_CH * img_cols * img_rows, activation="relu")(e9)
    # Reshape back to the input image shape
    e11 = Reshape((IN_CH, img_cols, img_rows))(e10)
    e12 = BatchNormalization()(e11)
    e13 = Activation('tanh')(e12)
    model = Model(inputs=inputs, outputs=e13)
    return model
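So far, my generator takes images from domain A as input, with the aim of outputting images of domain B. I also want to feed in the class information of the domain A input, so that the generator produces images of the same class for domain B. How can I add the label information after flattening, so that, for example, the input has size 1x1025 instead of 1x1024? Can I use a second input for the class information in the generator? If so, how can I then call the generator from the GAN's training procedure?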
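A minimal NumPy sketch of the shape arithmetic behind this question, assuming integer class labels and a stand-in 1024-wide activation batch (the names `features`, `n_classes` and the sizes are illustrative assumptions). Appending a scalar label gives 1025 columns, as in the question; appending a one-hot vector instead adds `n_classes` columns, which is the more common choice for a categorical condition.

```python
import numpy as np

# Illustrative sizes (assumptions, not taken from the question's data)
batch_size, n_classes = 4, 10

# Stand-in for the 1024-wide activations mentioned in the question
features = np.random.rand(batch_size, 1024).astype("float32")
labels = np.array([0, 3, 3, 9])  # integer class ids, one per sample

# Option 1: append the raw label as a single extra column -> 1025 columns
with_scalar = np.concatenate([features, labels[:, None].astype("float32")], axis=1)

# Option 2: append a one-hot encoding -> 1024 + n_classes columns
one_hot = np.eye(n_classes, dtype="float32")[labels]
with_one_hot = np.concatenate([features, one_hot], axis=1)

print(with_scalar.shape)   # (4, 1025)
print(with_one_hot.shape)  # (4, 1034)
```

A one-hot vector avoids implying an ordering between class ids, which a single scalar column would.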
Training procedure:
import numpy as np

discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
    generator, discriminator, classifier)
generator.compile(loss=generator_l1_loss, optimizer=g_optim)
discriminator_and_classifier_on_generator.compile(
    loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
    optimizer="rmsprop")
discriminator.compile(loss=discriminator_loss, optimizer=d_optim)  # rmsprop
classifier.compile(loss="categorical_crossentropy", optimizer=c_optim)
for epoch in range(30):
    for index in range(X_train.shape[0] // BATCH_SIZE):
        batch = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
        x_batch = X_train[batch]          # domain A inputs
        image_batch = Y_train[batch]      # matching domain B targets
        label_batch = LABEL_train[batch]  # replace with your data here
        generated_images = generator.predict(x_batch)
        # The discriminator sees (input, output) pairs stacked along the channel axis
        real_pairs = np.concatenate((x_batch, image_batch), axis=1)
        fake_pairs = np.concatenate((x_batch, generated_images), axis=1)
        X = np.concatenate((real_pairs, fake_pairs))
        y = np.concatenate((np.ones((BATCH_SIZE, 1, 64, 64)), np.zeros((BATCH_SIZE, 1, 64, 64))))
        d_loss = discriminator.train_on_batch(X, y)
        discriminator.trainable = False
        c_loss = classifier.train_on_batch(image_batch, label_batch)
        classifier.trainable = False
        g_loss = discriminator_and_classifier_on_generator.train_on_batch(
            x_batch, [image_batch, np.ones((BATCH_SIZE, 1, 64, 64)), label_batch])
        discriminator.trainable = True
        classifier.trainable = True
The code is an implementation of conditional DCGANs (with a classifier added on top of the discriminator). The network functions are:
from tensorflow.keras.layers import Input, Concatenate
from tensorflow.keras.models import Model

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)
    # Pair the input with the generated image along the channel axis
    merged = Concatenate(axis=1)([inputs, x_generator])
    discriminator.trainable = False
    x_discriminator = discriminator(merged)
    classifier.trainable = False
    x_classifier = classifier(x_generator)
    model = Model(inputs=inputs, outputs=[x_generator, x_discriminator, x_classifier])
    return model

def generator_containing_discriminator(generator, discriminator):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)
    merged = Concatenate(axis=1)([inputs, x_generator])
    discriminator.trainable = False
    x_discriminator = discriminator(merged)
    model = Model(inputs=inputs, outputs=[x_generator, x_discriminator])
    return model
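The discriminator in the training loop above is trained on (input, output) pairs stacked along the channel axis. The following is a small NumPy sketch of that batch construction. It uses illustrative shapes (3-channel 28x28 images, a batch of 2) and a hypothetical discriminator output shape `patch_shape`. The question's `(1, 64, 64)` target shape depends on a discriminator that is not shown, so it is not reproduced here.

```python
import numpy as np

batch_size, channels, rows, cols = 2, 3, 28, 28
patch_shape = (1, 7, 7)  # hypothetical discriminator output shape

x_batch = np.random.rand(batch_size, channels, rows, cols)  # domain A inputs
real_b = np.random.rand(batch_size, channels, rows, cols)   # matching domain B images
fake_b = np.random.rand(batch_size, channels, rows, cols)   # stand-in for generator output

# Pair each domain A input with a domain B image along the channel axis (axis=1)
real_pairs = np.concatenate((x_batch, real_b), axis=1)
fake_pairs = np.concatenate((x_batch, fake_b), axis=1)

# Real pairs first, then fake pairs; targets are 1 for real and 0 for fake
X = np.concatenate((real_pairs, fake_pairs))
y = np.concatenate((np.ones((batch_size,) + patch_shape),
                    np.zeros((batch_size,) + patch_shape)))
```

Deriving the target shape from `batch_size` rather than a literal `100` keeps `y` aligned with `X` if the batch size changes.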
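First, following the suggestion given in Conditional Generative Adversarial Nets, you have to define a second input. Then, simply concatenate the two input vectors and process the concatenated vector.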
from tensorflow.keras.layers import (Input, Concatenate, BatchNormalization, Flatten,
                                     Dense, Reshape, Activation)
from tensorflow.keras.models import Model

def generator_model_v2():
    # Two inputs: the domain A image and its class label (e.g. one-hot)
    input_image = Input((IN_CH, img_cols, img_rows))
    input_conditional = Input((n_classes,))  # the shape must be a tuple
    e0 = Flatten()(input_image)
    # Append the label vector to the flattened image
    e1 = Concatenate()([e0, input_conditional])
    e2 = BatchNormalization()(e1)
    e3 = BatchNormalization()(e2)
    e4 = Dense(1024, activation="relu")(e3)
    e5 = BatchNormalization()(e4)
    e6 = Dense(512, activation="relu")(e5)
    e7 = BatchNormalization()(e6)
    e8 = Dense(512, activation="relu")(e7)
    e9 = BatchNormalization()(e8)
    e10 = Dense(IN_CH * img_cols * img_rows, activation="relu")(e9)
    e11 = Reshape((IN_CH, img_cols, img_rows))(e10)
    e12 = BatchNormalization()(e11)
    e13 = Activation('tanh')(e12)
    model = Model(inputs=[input_image, input_conditional], outputs=e13)
    return model
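Why concatenating the condition works: in the first `Dense` layer after the concatenation, the weight matrix splits into an image block and a label block. With a one-hot label, that layer therefore adds a learned, class-specific bias to the image term. The following NumPy check of that identity uses arbitrary illustrative sizes and random weights. It ignores the `BatchNormalization` layers that sit between the concatenation and the `Dense` layer in the code above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_classes, n_units = 12, 4, 5

x = rng.standard_normal(n_features)  # flattened image
label = 2
c = np.eye(n_classes)[label]         # one-hot condition

W = rng.standard_normal((n_features + n_classes, n_units))  # Dense kernel
b = rng.standard_normal(n_units)                             # Dense bias

# Dense layer applied to the concatenated vector [x, c]
concat_out = np.concatenate([x, c]) @ W + b

# Same result: image term plus the row of the label block selected by the one-hot vector
W_x, W_c = W[:n_features], W[n_features:]
split_out = x @ W_x + W_c[label] + b

print(np.allclose(concat_out, split_out))  # True
```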
Then, you also need to pass the class labels to the generator during training, alongside the domain A images:
generated_images = generator.predict([X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE], label_batch])
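With a two-input generator, every call to it needs the image batch and the matching label batch as a list, for example `generator.predict([x_batch, label_batch])`. This also applies to a combined GAN model, assuming it is rebuilt with the same two inputs. Below is a hypothetical helper, `iterate_batches` (not part of Keras), that slices the domain A images, domain B targets and labels with the same indices so each condition stays attached to its image. The toy arrays are placeholders.

```python
import numpy as np

def iterate_batches(x_a, y_b, labels, batch_size):
    """Yield aligned (domain A batch, domain B batch, label batch) triples.

    Hypothetical helper: a trailing partial batch is dropped, like the
    question's int(X_train.shape[0] / BATCH_SIZE) loop.
    """
    n_batches = x_a.shape[0] // batch_size
    for index in range(n_batches):
        sl = slice(index * batch_size, (index + 1) * batch_size)
        yield x_a[sl], y_b[sl], labels[sl]

# Toy data: 10 samples, 5 classes, one-hot labels
n_samples, n_classes = 10, 5
X_train = np.random.rand(n_samples, 3, 28, 28)
Y_train = np.random.rand(n_samples, 3, 28, 28)
LABEL_train = np.eye(n_classes)[np.arange(n_samples) % n_classes]

batches = list(iterate_batches(X_train, Y_train, LABEL_train, batch_size=4))
# In the training loop each triple would feed the two-input generator, e.g.
#   generated_images = generator.predict([x_batch, label_batch])
print(len(batches))  # 2 (the last 2 samples are dropped)
```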
Views: 187