小编Zho*_*_11的帖子

Tensorflow 2中的fit方法中使用Dataset和ndarray有什么区别?

作为 TF 的新手,我对 BatchDataset 在训练模型中的使用感到有点困惑。

让我们以 MNIST 为例。在这个分类任务中,我们可以加载数据并将x_trian、y_train的ndarray直接输入到模型中。

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train,y_train, epochs=5)

Run Code Online (Sandbox Code Playgroud)

训练结果为:

Epoch 1/5
2021-02-17 15:43:02.621749: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library cublas64_10.dll
   1/1875 [..............................] - ETA: 0s - loss: 2.2977 - accuracy: 0.0938WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_end` time: …
Run Code Online (Sandbox Code Playgroud)

keras tensorflow

11
推荐指数
1
解决办法
1733
查看次数

标签 统计

keras ×1

tensorflow ×1