Col*_*ann 7 python-3.x tensorflow tensorboard tf.keras tensorflow2.0
我已升级到Tensorflow 2.0,没有tf.summary.FileWriter("tf_graphs", sess.graph)
。我正在研究有关此问题的其他StackOverflow问题,他们说可以使用tf.compat.v1.summary etc
。当然,必须有一种在Tensorflow版本2中绘制和可视化tf.keras模型的方法。这是什么?我正在寻找以下张量板输出。谢谢!
根据docs,一旦您的模型经过训练,您就可以使用 Tensorboard 来可视化图形。
首先,定义您的模型并运行它。然后,打开 Tensorboard 并切换到 Graph 选项卡。
最小的可编译示例
此示例取自文档。首先,定义您的模型和数据。
# Relevant imports.
%load_ext tensorboard
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
from packaging import version
import tensorflow as tf
from tensorflow import keras
# Define the model.
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0
Run Code Online (Sandbox Code Playgroud)
接下来,训练您的模型。在这里,您需要为 Tensorboard 定义一个回调以用于可视化统计数据和图表。
# Define the Keras TensorBoard callback.
logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
# Train the model.
model.fit(
train_images,
train_labels,
batch_size=64,
epochs=5,
callbacks=[tensorboard_callback])
Run Code Online (Sandbox Code Playgroud)
训练后,在你的笔记本中,运行
%tensorboard --logdir logs
Run Code Online (Sandbox Code Playgroud)
并切换到导航栏中的 Graph 选项卡:
你会看到一个看起来很像这样的图表:
You can visualize the graph of any tf.function
decorated function, but first, you have to trace its execution.
Visualizing the graph of a Keras model means to visualize it's call
method.
默认情况下,此方法未tf.function
修饰,因此您必须将模型调用包装在正确修饰的函数中并执行。
import tensorflow as tf
model = tf.keras.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(32, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax"),
]
)
@tf.function
def traceme(x):
return model(x)
logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
traceme(tf.zeros((1, 28, 28, 1)))
with writer.as_default():
tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
925 次 |
最近记录: |