从 C+ 生成 TFRecord 格式数据

khk*_*ens 5 c++ python protocol-buffers tensorflow tfrecord

我正在尝试使用TFRecord 格式记录 C++ 中的数据,然后在 python 中使用它来提供 TensorFlow 模型。

太长了;简单地将原始消息序列化为流并不能满足.tfrecordPythonTFRecordDataset类的格式要求。C++ 中是否有相当于 Python 的工具TfRecordWriter(在 TensorFlow 或 Google Protobuf 库中)来生成正确的.tfrecord数据?

细节:

简化的 C++ 代码如下所示:

tensorflow::Example sample;
sample.mutable_features()->mutable_feature()->operator[]("a").mutable_float_list()->add_value(1.0);

std::ofstream out;
out.open("cpp_example.tfrecord", std::ios::out | std::ios::binary);
sample.SerializeToOstream(&out);
Run Code Online (Sandbox Code Playgroud)

在Python中,为了创建TensorFlow数据,我尝试使用TFRecordDataset,但显然它需要 .tfrecord 文件中的额外页眉/页脚信息(而不是简单的序列化原始消息列表):

import tensorflow as tf
tfrecord_dataset = tf.data.TFRecordDataset(filenames="cpp_example.tfrecord")
next(tfrecord_dataset.as_numpy_iterator())
Run Code Online (Sandbox Code Playgroud)

输出:

tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNext]
Run Code Online (Sandbox Code Playgroud)

请注意,记录的二进制文件没有任何问题,因为以下代码打印了有效的输出:

import tensorflow as tf
p = open("cpp_example.tfrecord", "rb")
example = tf.train.Example.FromString(p.read())
Run Code Online (Sandbox Code Playgroud)

输出:

features {
  feature {
    key: "a"
    value {
      float_list {
        value: 1.0
      }
    }
  }
}
Run Code Online (Sandbox Code Playgroud)

通过分析我的 C++ 示例生成的二进制输出以及使用 Python 生成的输出TfRecordWriter,我观察到内容中存在额外的页眉和页脚字节。不幸的是,这些额外的字节代表的是实现细节(可能是压缩类型和一些额外的信息),并且我无法比 python 库中的某些类更深入地跟踪它,这些类只是从_pywrap_tfe.so.

这样的建议说这.tfrecord只是一个普通的谷歌 protobuf 数据。我可能缺少在哪里可以找到 protobuf 数据编写器的知识(期望将 proto 消息序列化到输出流中)?

khk*_*ens 2

事实证明tensorflow::io::RecordWriterTensorFlow C++ 库的类可以完成这项工作。

#include <tensorflow/core/lib/io/record_writer.h>

#include <tensorflow/core/platform/default/posix_file_system.h>
#include <tensorflow/core/example/example.pb.h>

// ...

// Create WritableFile and instantiate RecordWriter.
tensorflow::PosixFileSystem posixFileSystem;
std::unique_ptr<tensorflow::WritableFile> writableFile;

posixFileSystem.NewWritableFile("cpp_example.tfrecord", &writableFile);

tensorflow::io::RecordWriter recordWriter(mWritableFile.get(), tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions(""));

// ...
tensorflow::Example sample;

// ...

// Serialize proto message into a buffer and record in tfrecord format.
std::string buffer;
sample.SerializeToString(&buffer);
recordWriter.WriteRecord(buffer);

Run Code Online (Sandbox Code Playgroud)

如果从TFRecord 文档中的某个位置引用此类将会很有帮助。