有没有正确的方法来子类 Tensorflow 的数据集?

Nic*_*Shu 7 python tensorflow

我一直在研究自定义 Tensorflow 数据集的不同方法,并且我习惯于查看PyTorch 的数据集,但是当我查看Tensorflow 的数据集时,我看到了这个示例:

class ArtificialDataset(tf.data.Dataset):
  def _generator(num_samples):
    # Opening the file
    time.sleep(0.03)

    for sample_idx in range(num_samples):
      # Reading data (line, record) from the file
      time.sleep(0.015)

      yield (sample_idx,)

  def __new__(cls, num_samples=3):
    return tf.data.Dataset.from_generator(
        cls._generator,
        output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
        args=(num_samples,)
        )
Run Code Online (Sandbox Code Playgroud)

但出现了两个问题:

  1. 看起来它所做的只是在实例化对象时,该__new__方法仅调用tf.data.Dataset.from_generator静态方法。那么为什么不直接调用它呢?为什么有一个甚至子类化的点tf.data.Dataset?有没有使用过的方法tf.data.Dataset
  2. __iter__有没有一种方法可以像数据生成器一样,在继承时填写一个方法tf.data.Dataset?我不知道,就像
class MyDataLoader(tf.data.Dataset):
  def __init__(self, path, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.data = pd.read_csv(path)

  def __iter__(self):
    for datum in self.data.iterrows():
      yield datum
Run Code Online (Sandbox Code Playgroud)

非常感谢大家!

小智 2

问题1

该示例只是将数据集与生成器封装在类中。它继承自tf.data.Datasetbecausefrom_generator()返回一个tf.data.Dataset基于对象。但是,tf.data.Dataset如示例中所示,没有使用 的方法。因此,回答问题1:是的,可以直接调用而不使用类。

问题2

是的。可以这样做。

另一种类似的方法是tf.keras.utils.Sequence这里一样使用。