如何对 Tensorflow 数据集进行“一次热编码”?

Mar*_*ace 7 python tensorflow one-hot-encoding

新人在这里...我TF按如下方式加载数据集:

dataset = tf.data.TFRecordDataset(files)
dataset.map(extract_fn)
Run Code Online (Sandbox Code Playgroud)

数据集包含一个带有一些值的“字符串列”,我想对它们进行“one-hot”编码。extract_fn如果我有索引和深度(我现在只有一个 String 值),我可以逐条记录地做到这一点。但是,是否有 TF 功能可以为我做到这一点?IE

  • 计算不同值的数量
  • 将每个值映射到一个索引
  • 为此创建一个单热编码列

jde*_*esa 0

我认为这符合你的要求:

import tensorflow as tf
def one_hot_any(a):
    # Save original shape
    s = tf.shape(a)
    # Find unique values
    values, idx = tf.unique(tf.reshape(a, [-1]))
    # One-hot encoding
    n = tf.size(values)
    a_1h_flat = tf.one_hot(idx, n)
    # Reshape to original shape
    a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
    return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
    print(*sess.run([x_1h, x_vals]), sep='\n')
Run Code Online (Sandbox Code Playgroud)

输出:

import tensorflow as tf
def one_hot_any(a):
    # Save original shape
    s = tf.shape(a)
    # Find unique values
    values, idx = tf.unique(tf.reshape(a, [-1]))
    # One-hot encoding
    n = tf.size(values)
    a_1h_flat = tf.one_hot(idx, n)
    # Reshape to original shape
    a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
    return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
    print(*sess.run([x_1h, x_vals]), sep='\n')
Run Code Online (Sandbox Code Playgroud)

但问题是,不同的输入会产生不一致的输出,具有不同的值顺序,甚至不同的单热深度,所以我不确定它是否真的有用。