How to sequentially combine 2 TensorFlow models?

Upe*_*ahu 6 python machine-learning neural-network tensorflow

I have 2 TensorFlow models, both with the same architecture (Unet-3d). My current pipeline is:

preprocessing -> model 1 prediction -> some operations -> model 2 prediction -> post-processing

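The operations in between can be done in TF. Can we combine the two models and the operations between them into a single TF graph, so that the pipeline looks like this: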

preprocessing -> model 1+2 -> post-processing

Thanks.

Sri*_*adi 14

You can use the `tf.keras` functional API to do this. Here is a toy example:

```python
import tensorflow as tf
print('TensorFlow:', tf.__version__)

def preprocessing(tensor):
    # perform your operations
    return tensor

def some_operations(model_1_prediction):
    # perform your operations
    # assuming your operations result in a tensor
    # whose shape matches model_2's input
    tensor = model_1_prediction
    return tensor

def post_processing(tensor):
    # perform your operations
    return tensor

def get_model(name):
    inp = tf.keras.Input(shape=[256, 256, 3])
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(inp)
    x = tf.keras.layers.Conv2D(256, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(512, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(3, 3, 1, 'same')(x)
    # num_filters is set to 3 to make sure model_1's output
    # matches model_2's input.
    output = tf.keras.layers.Activation('sigmoid')(x)
    return tf.keras.Model(inp, output, name=name)

model_1 = get_model('model-1')
model_2 = get_model('model-2')

# Chain the graphs: model_1's symbolic output goes through the
# intermediate operations and is then fed into model_2.
x = some_operations(model_1.output)
out = model_2(x)
model_1_2 = tf.keras.Model(model_1.input, out, name='model-1+2')

model_1_2.summary()
```

Output:

```
TensorFlow: 2.1.0-rc0
Model: "model-1+2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 256, 256, 64)      1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 256, 256, 256)     147712    
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 512)     1180160   
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 256, 256, 64)      294976    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 256, 256, 3)       1731      
_________________________________________________________________
activation (Activation)      (None, 256, 256, 3)       0         
_________________________________________________________________
model-2 (Model)              (None, 256, 256, 3)       1626371   
=================================================================
Total params: 3,252,742
Trainable params: 3,252,742
Non-trainable params: 0
_________________________________________________________________
```