Tesnorflow:如何为 tf.nn.sampled_softmax_loss 提供您自己的 `sampled_values`?

San*_*ta7 5 python tensorflow

在 tf.nn.sampled_softmax_loss 中,可选输入之一是放置您自己的样本值。我想提供我自己的样本值,以便我可以使用 float16(半精度)变量。如果sampled_values留空,Tensorflow 将使用log_uniform_candidate_sampler获取值,该值只能返回 float32。

这里是所有的输入。

tf.nn.sampled_softmax_loss(
    weights,
    biases,
    labels,
    inputs,
    num_sampled,
    num_classes,
    num_true=1,
    sampled_values=None,
    remove_accidental_hits=True,
    partition_strategy='mod',
    name='sampled_softmax_loss',
    seed=None
)
Run Code Online (Sandbox Code Playgroud)

https://www.tensorflow.org/api_docs/python/tf/nn/sampled_softmax_loss

这是他们为 sampled_values arg 提供的信息:

sampled_values:*_candidate_sampler 函数返回的 (sampled_candidates, true_expected_count, sampled_expected_count) 元组。(如果没有,我们默认为 log_uniform_candidate_sampler)

我想弄清楚如何提供这个元组。sampled_candidates, true_expected_count,究竟是什么sampled_expected_count

我知道它正在对权重和相应的偏差进行采样,所以我是否将它们放在它自己的元组中sampled_candidates?另外,我是将 int 放在矩阵中的权重位置,还是将整个嵌入本身放入?

我还查看了 Tensorflow 对负采样的数学补充,但找不到有关我的问题的任何信息https://www.tensorflow.org/extras/candidate_sampling.pdf

在我的搜索中,我在谷歌论坛上发现了这个非常相似的问题

https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/6IDJ-XAIb9M

给出的答案是

sampled_values是我们的 *candidate_sampler 类返回的元组。这些类实现了根据一些分布 Q 对对比标签(未观察到,但在训练期间使用)进行采样的方法,以用于近似训练方法,如噪声对比估计 (NCE) 和采样 Softmax。一个例子是 log_uniform_candidate_sampler,它根据对数均匀分布对标签进行采样。

您几乎不需要自己提供这些。您只需将调用结果传递给 tf.nn 模块中的 *candidate_sampler 函数(其中 * 可以是“uniform”、“log_uniform”、“zipfian_binned”等),例如

sampled_values = tf.nn.zipfian_binned_candidate_sampler(...)

如果您只是想让它工作,只需将其保留为 None,它将默认为 log_uniform_candidate_sampler(通常是一个不错的选择)。

如果您对此背后的数学感兴趣,请参阅此文档:https : //www.tensorflow.org/versions/r0.8/extras/candidate_sampling.pdf

但是要回答您的问题:对于每批观察到的标签 L 和候选抽样分布 Q,元组包括:

  • 具有实际采样对比标签 N 的张量,
  • 具有在 Q 下观察到的标签 L 的对数期望值的张量,即 log Q(L),和
  • 具有 Q 下对比标签的对数期望值的张量,即 log Q(N)。

后者是数学通过所必需的(见上面的文件)。所以 sampled_values 包含(希望清楚地滥用符号):

sampled_values = (N, log Q(L), log Q(N))。

但是,我仍然不知道如何输入值。我不确定数据类型应该是什么,以及 N 是嵌入矩阵中的 int 位置还是嵌入本身。另外,我猜 N 应该是一个值列表本身,即我们必须采样的负标签数量的大小。

我想知道我是否可以得到一个带有一些值的例子。例如,对于 3 的负采样,我是否做这样的事情?

采样值 = ([4,29, 12], [1, 1, 1], [0, 0, 0])

此外,文档说元组应该“由 *_candidate_sampler 函数返回”

这是否意味着我需要提供一个返回元组的函数,而不是元组本身?