小编jin*_*ng 的帖子

如何规范张量流中模型的输入数据

我的训练数据保存在3个文件中,每个文件太大而无法放入内存中.对于每个训练示例,数据是二维的(2805行和222列,第222列用于标签)并且是数值.我想在进入训练模型之前对数据进行标准化.下面是我的input_pipeline代码,在创建数据集之前数据尚未规范化.tensorflow中是否有一些函数可以对我的情况进行规范化?

dataset = tf.data.TextLineDataset([file1, file2, file3])
# combine 2805 lines into a single example
dataset = dataset.batch(2805)

def parse_example(line_batch):
    record_defaults = [[1.0] for col in range(0, 221)]
    record_defaults.append([1])
    content = tf.decode_csv(line_batch, record_defaults = record_defaults, field_delim = '\t')
    features = tf.stack(content[0:221])
    features = tf.transpose(features)
    label = content[-1][-1]
    label = tf.one_hot(indices = tf.cast(label, tf.int32), depth = 2)
    return features, label

dataset = dataset.map(parse_example)
dataset = dataset.shuffle(1000)
# batch multiple examples
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
data_batch, label_batch = …
Run Code Online (Sandbox Code Playgroud)

tensorflow

11
推荐指数
2
解决办法
1万
查看次数

标签 统计

tensorflow ×1