如何在pytorch中加载tfrecord?

Whi*_*sht 7 pytorch tfrecord

如何在 pytorch 中使用 tfrecord?

我已经下载了具有视频级特征的“Youtube8M”数据集,但它存储在tfrecord中。我尝试从这些文件中读取一些示例,将其转换为 numpy,然后加载到 pytorch 中。但它失败了。

    reader = YT8MAggregatedFeatureReader()
    files = tf.gfile.Glob("/Data/youtube8m/train*.tfrecord")
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=5, shuffle=True)
    training_data = [
        reader.prepare_reader(filename_queue) for _ in range(1)
    ]

    unused_video_id, model_input_raw, labels_batch, num_frames = tf.train.shuffle_batch_join(
        training_data,
        batch_size=1024,
        capacity=1024 * 5,
        min_after_dequeue=1024,
        allow_smaller_final_batch=True  ,
        enqueue_many=True)

    with tf.Session() as sess:
        label_numpy = labels_batch.eval()
        print(type(label_numpy))

Run Code Online (Sandbox Code Playgroud)

但这一步却没有任何结果,只是卡了半天没有任何反应。

Mau*_*avi 0

也许这可以帮助你: TFRecord reader for PyTorch