小编Hac*_*ker的帖子

当传递无限重复的数据集时,必须指定“steps_per_epoch”参数

我正在尝试使用谷歌的示例,但使用我自己的数据集:

https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/text_classification.ipynb

我创建了一个类似于代码中下载的文件夹,其中包含训练和测试文件夹以及 txt 文件。

就我而言,data_path 如下: data_path = '/Users/developer/.keras/datasets/chat'

每当我尝试运行时,它model = text_classifier.create(train_data)都会抛出错误, ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument. 这是什么意思以及我应该在哪里寻找问题?


import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_customization.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_customization.core.model_export_format import ModelExportFormat
import tensorflow_examples.lite.model_customization.core.task.text_classifier as text_classifier


# data_path = tf.keras.utils.get_file(
#       fname='aclImdb',
#       origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
#       untar=True)

data_path = '/Users/developer/.keras/datasets/chat'

train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), class_labels=['greeting', 'goodbye'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), shuffle=False)

model = text_classifier.create(train_data) …
Run Code Online (Sandbox Code Playgroud)

python tensorflow tensorflow-datasets tensorflow-lite tensorflow2.0

6
推荐指数
1
解决办法
1万
查看次数