如何使用tf.data.Dataset.from_generator()将参数发送到生成器函数?

mic*_*rer 6 python python-3.x tensorflow tensorflow-datasets

我想创建一些tf.data.Dataset使用该from_generator()功能.我想向生成器函数(raw_data_gen)发送一个参数.这个想法是生成器函数将根据发送的参数产生不同的数据.通过这种方式,我希望raw_data_gen能够提供培训,验证或测试数据.

training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))
Run Code Online (Sandbox Code Playgroud)

我尝试以from_generator()这种方式调用时收到的错误消息是:

TypeError: from_generator() got an unexpected keyword argument 'args'
Run Code Online (Sandbox Code Playgroud)

这是raw_data_gen函数,虽然我不确定你是否需要这个,因为我的预感是问题是调用from_generator():

def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()
Run Code Online (Sandbox Code Playgroud)

xdu*_*ch0 7

您需要基于新参数定义一个raw_data_gen不带任何参数的函数。您可以使用lambda关键字来执行此操作。

training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...
Run Code Online (Sandbox Code Playgroud)

现在,我们将from_generator不带任何参数的函数传递给该函数,但是它将仅raw_data_gen将参数设置为1 即可起作用。您可以对验证集和测试集使用相同的方案,分别传递2和3。

  • 哦,对了,我忘记了他们添加了“ args”东西。但是,这是一个非常新的更新,显然是在1.9中引入的。也许您正在使用过时的Tensorflow版本? (2认同)