我正在遵循官方的 Tensorflow预处理层教程,我不确定我是否明白为什么在分类编码后最终会得到这些额外的列。[ 2024 年更新:默认行为现已更改,因此当前教程不再给出我在下面显示的确切结果]
这是一个精简的最小可重现示例(包括数据):
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
import pathlib
dataset_url = 'http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'
csv_file = 'datasets/petfinder-mini/petfinder-mini.csv'
tf.keras.utils.get_file('petfinder_mini.zip', dataset_url, extract=True, cache_dir='.')
df = pd.read_csv(csv_file)
# In the original dataset "4" indicates the pet was not adopted.
df['target'] = np.where(df['AdoptionSpeed']==4, 0, 1)
# Drop un-used columns.
df = df.drop(columns=['AdoptionSpeed', 'Description'])
# A utility method to create a tf.data dataset from a Pandas Dataframe
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
dataframe = dataframe.copy()
labels = dataframe.pop('target')
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
if shuffle:
ds = ds.shuffle(buffer_size=len(dataframe))
ds = ds.batch(batch_size)
ds = ds.prefetch(batch_size)
return ds
batch_size = 5
ds = df_to_dataset(df, batch_size=batch_size)
[(train_features, label_batch)] = ds.take(1)
def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):
# Create a StringLookup layer which will turn strings into integer indices
if dtype == 'string':
index = preprocessing.StringLookup(max_tokens=max_tokens)
else:
index = preprocessing.IntegerLookup(max_values=max_tokens)
# Prepare a Dataset that only yields our feature
feature_ds = dataset.map(lambda x, y: x[name])
# Learn the set of possible values and assign them a fixed integer index.
index.adapt(feature_ds)
# Create a Discretization for our integer indices.
encoder = preprocessing.CategoryEncoding(max_tokens=index.vocab_size())
#encoder = preprocessing.CategoryEncoding(max_tokens=2)
# Prepare a Dataset that only yields our feature.
feature_ds = feature_ds.map(index)
# Learn the space of possible indices.
encoder.adapt(feature_ds)
# Apply one-hot encoding to our indices. The lambda function captures the
# layer so we can use them, or include them in the functional model later.
return lambda feature: encoder(index(feature))
Run Code Online (Sandbox Code Playgroud)
所以,运行后
type_col = train_features['Type']
layer = get_category_encoding_layer('Type', ds, 'string')
layer(type_col)
Run Code Online (Sandbox Code Playgroud)
我得到的结果是:
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 0., 1., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.],
[0., 0., 0., 1.],
[0., 0., 0., 1.]], dtype=float32)>
Run Code Online (Sandbox Code Playgroud)
类似于教程中显示的内容。
请注意,这是一个二元分类问题(猫/狗):
np.unique(type_col)
# array([b'Cat', b'Dog'], dtype=object)
Run Code Online (Sandbox Code Playgroud)
那么,上面结果中显示的分类编码后的 2 个额外列的逻辑是什么?它们代表什么,为什么它们是 2(而不是 1、3 或更多)?
(我完全清楚,如果我希望进行简单的独热编码,我可以简单地使用to_categorical(),但这不是这里的问题)
正如问题中已经暗示的那样,分类编码比简单的单热编码更丰富。print要查看这两列代表什么,只需在函数内的某处添加诊断即可get_category_encoding_layer():
print(index.get_vocabulary())
Run Code Online (Sandbox Code Playgroud)
那么最后命令的结果将是:
['', '[UNK]', 'Dog', 'Cat']
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 0., 1., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.],
[0., 0., 0., 1.],
[0., 0., 0., 1.]], dtype=float32)>
Run Code Online (Sandbox Code Playgroud)
希望提示应该很清楚:这里额外的两列分别代表空值''和未知值'[UNK]',它们可能出现在未来(看不见的)数据中。
这实际上是根据默认参数确定的,不是 的CategoryEncoding,而是前面的StringLookup; 来自文档:
mask_token=''
oov_token='[UNK]'
Run Code Online (Sandbox Code Playgroud)
oov_token=''通过询问而不是oov_token='[UNK]'; ,您最终可以得到更严格的编码(只有 1 个额外列而不是 2 个)。将函数StringLookup中的调用替换为get_category_encoding_layer()
index = preprocessing.StringLookup(oov_token='',mask_token=None, max_tokens=max_tokens)
Run Code Online (Sandbox Code Playgroud)
之后,结果将是:
['', 'Dog', 'Cat']
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.]], dtype=float32)>
Run Code Online (Sandbox Code Playgroud)
即只有 3 列(没有专门的一列'[UNK]')。AFAIK,这是您可以达到的最低值 - 尝试同时设置mask_token和oov_tokentoNone将导致错误。
| 归档时间: |
|
| 查看次数: |
744 次 |
| 最近记录: |