MNIST 的最小扩散模型 (DDIM)

Edv*_*Beq 5 python neural-network keras tensorflow

为了学习的目的,我为 MNIST 数据集创建了一个最小的 DDIM。除了扩散数学之外的一切我都认为是“额外的”。

这是额外的列表:

  • 优网
  • 位置嵌入
  • 扩散时间表
  • 数据集标准化
  • 指数移动平均线

使用最小示例的原因是因为我不理解这些其他技巧的贡献。因此,如果我从更简单的事情开始 - 我可以看到额外优化的贡献。

简化网络,IMO 也是一个泛化步骤,因此该方法可以应用于其他问题。

该代码借鉴自这个伟大的 Keras 示例:https ://keras.io/examples/generative/ddim/

只是澄清一下,要回答这个问题,我们可以提供一个比 u-net 更简单的网络,我们可以识别一些数字,或者解释为什么我们需要 u-net

我在 Keras 博客的架构技巧下读到了原作者写的一些有趣的内容。它说:

跳过连接:在网络架构中使用跳过连接绝对至关重要,没有它们,模型将无法以良好的性能学习去噪。

代码 - 更新以删除所有额外内容

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os

print("tf version: ", tf.__version__)

# data
diffusion_steps = 20
image_size = 28

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# optimization
batch_size = 64
num_epochs = 1000
learning_rate = 1e-3

embedding_dims = 32
embedding_max_frequency = 1000.0


x0 = tf.keras.Input(shape=(28, 28, 1))
t0 = tf.keras.Input(shape=(1, 1, 1))

combined = tf.keras.layers.Add()([x0, t0])

x = tf.keras.layers.Flatten()(combined)
x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(x)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(
    64, 3, activation="relu", strides=2, padding="same"
)(x)
x = tf.keras.layers.Conv2DTranspose(
    32, 3, activation="relu", strides=2, padding="same"
)(x)
output = tf.keras.layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
network = tf.keras.Model(inputs=[x0, t0], outputs=output)
print(network.summary())


class DiffusionModel(tf.keras.Model):
    def __init__(self, network):
        super().__init__()

        self.normalizer = tf.keras.layers.Normalization()
        self.network = network

    def compile(self, **kwargs):
        super().compile(**kwargs)
        self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        return tf.clip_by_value(images, 0.0, 1.0)

    # predictive stage
    def denoise(self, noisy_images, times, training):
        # predict noise component and calculate the image component using it
        with tf.GradientTape() as tape:
            tape.watch(noisy_images)
            pred_noises = self.network([noisy_images, times**2], training=training)
            gradients = tape.gradient(pred_noises, noisy_images)
        pred_images = noisy_images - pred_noises - gradients
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, steps):
        # reverse diffusion = sampling
        batch = initial_noise.shape[0]
        step_size = 1.0 / steps
        
        next_noisy_images = initial_noise
        next_diffusion_times = tf.ones((batch, 1, 1, 1))

        for step in range(diffusion_steps):
            noisy_images = next_noisy_images
            diffusion_times = next_diffusion_times

            pred_noises, pred_images = self.denoise(
                noisy_images, diffusion_times, training=False
            )

            # this new noisy image will be used in the next step
            next_diffusion_times = diffusion_times - step_size
            next_noisy_images = pred_images + pred_noises
        return pred_images

    def generate(self, num_images, steps):
        # noise -> images -> denormalized images
        initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 1))
        generated_images = self.reverse_diffusion(initial_noise, steps)
        return generated_images

    def train_step(self, images):
        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 1))
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )

        with tf.GradientTape(persistent=True) as tape:
            noisy_images = images + noises
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, diffusion_times, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric
            
            # total_loss = noise_loss + image_loss

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)
        return {m.name: m.result() for m in self.metrics}

    def plot_images(
        self,
        epoch=None,
        logs=None,
        num_rows=3,
        num_cols=6,
        write_to_file=True,
        output_dir="output",
    ):
        # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            steps=diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")

        plt.tight_layout()

        if write_to_file:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            if epoch is not None:
                filename = os.path.join(
                    output_dir, "image_epoch_{:04d}.png".format(epoch)
                )
            else:
                import time

                timestr = time.strftime("%Y%m%d-%H%M%S")
                filename = os.path.join(output_dir, "image_{}.png".format(timestr))
            plt.savefig(filename)
        else:
            plt.show()

        plt.close()


# create and compile the model
model = DiffusionModel(network)
model.compile(
    optimizer=tf.keras.optimizers.experimental.AdamW(learning_rate=learning_rate),
    # loss=tf.keras.losses.mean_squared_error,
    loss=tf.keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

# save the best model based on the noise loss
checkpoint_path = "checkpoints/diffusion_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="n_loss",
    mode="min",
    save_best_only=True,
)

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)


# run training and plot generated images periodically
model.fit(
    dataset,
    epochs=num_epochs,
    batch_size=batch_size,
    callbacks=[
        tf.keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images(write_to_file=False)
Run Code Online (Sandbox Code Playgroud)

更新1

我删除了上面列出的所有“额外内容”,包括调度程序和标准化器 - 并更改了去噪方法以获取预测噪声相对于噪声图像的导数。

def denoise(self, noisy_images, times, training):
    # predict noise component and calculate the image component using it
    with tf.GradientTape() as tape:
        tape.watch(noisy_images)
        pred_noises = self.network([noisy_images, times**2], training=training)
        gradients = tape.gradient(pred_noises, noisy_images)
    pred_images = noisy_images - pred_noises - gradients
    return pred_noises, pred_images
    
Run Code Online (Sandbox Code Playgroud)

这样做的结果是你可以看到那里有一些东西(数字)而不仅仅是噪音。因此,这种希望的暗示和下面好心人提出的改进带来了更多的改进。

修复

  • 删除了 @xdurch0 指出的注释块
  • 修复了 @Maciej Skorski 指出的非规范化方法
  • 按照 @Daraan 的建议和原始代码作者的评论,在 update-2 中添加了跳过连接

更新2

造成最大差异的最大变化是添加了跳跃连接。输入 x 和 t 相乘,展平并形成具有“线性”激活函数的密集层。这非常重要。然后将密集层添加到网络的输出中。我认为这有助于梯度消失,但可能还有更多。

x0 = tf.keras.Input(shape=(28, 28, 1))
t0 = tf.keras.Input(shape=(1, 1, 1))

combined = tf.keras.layers.Add()([x0, t0])
x = tf.keras.layers.Flatten()(combined)
x = tf.keras.layers.Dense(784, activation="linear")(x)
x1 = tf.keras.layers.Reshape((28, 28, 1))(x)

x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(x)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(
    64, 3, activation="relu", strides=2, padding="same"
)(x)
x = tf.keras.layers.Conv2DTranspose(
    32, 3, activation="relu", strides=2, padding="same"
)(x)
x = tf.keras.layers.Conv2DTranspose(1, 3, activation="relu", padding="same")(x)

output = tf.keras.layers.Add()([x, x1])
network = tf.keras.Model(inputs=[x0, t0], outputs=output) 
Run Code Online (Sandbox Code Playgroud)

凭借这个相对较小的网络和“梯度技巧”,我能够到达这里,这远远超出了我最初的目标。

在此输入图像描述

接下来,我添加了标准化和调度程序。归一化强调像素 - 使它们在较高值下更加密集。调度程序有助于训练。所以最终结果如下:

  1. 一次跳过连接、标准化、调度程序和“梯度技巧”

在此输入图像描述

  1. 一次跳过连接、归一化、调度程序,没有“梯度技巧”——类似的训练参数。

在此输入图像描述

我认为这些结果很棒。在图像生成中需要高保真度,但像这样的好的结果可以在其他领域有用。我通过反复试验得出的渐变技巧确实让我感到惊讶。我很想听听任何碰巧看到这一点的研究人员或学者的任何想法。

Mac*_*ski 0

总评:自上而下

我建议从上到下。

开始从工作实现中消除组件,看看后果。而不是将它们添加到无法正常工作的实现之上。

keras-team 的官方实现为例,它运行在Colab 上,可以在 GPU 上快速训练。我做了一些小的修改来适应 MNIST这个修改后的笔记本证明它可以训练

现在,我们要删除什么?

当前代码的问题

注意(代码中的逻辑错误)我看到一些错误,例如denormalize没有正确反转(忽略标准缩放)。与我引用的笔记本进行比较。

与我引用的笔记本相比, NOTE(性能不佳的代码模式) tf.dataset可以更有效地使用。

注意(错误的架构)建议的网络没有与 耦合DiffusionModel,例如不利用噪声率。

注意(性能提示)对于计算机视觉或大型语言模型,我建议在 Colab 或 Kaggle 上运行以获得一些免费的 GPU 带宽。

工作代码

我修复了扩散代码的问题并连接到该公共笔记本ResUNet表明它运行良好。因此,我们将问题缩小到网络实现上:它也可能有错误或不合适。