小编Ale*_*o B的帖子

使用TensorFlow中的3D卷积进行批量标准化

我正在实施一个依赖于3D卷积的模型(用于类似于动作识别的任务),我想使用批量标准化(参见[Ioffe&Szegedy 2015]).我找不到任何专注于3D转换的教程,因此我在这里做一个简短的教程,我想和你一起回顾.

下面的代码引用TensorFlow r0.12并且它显式实例变量 - 我的意思是我没有使用tf.contrib.learn,除了tf.contrib.layers.batch_norm()函数.我这样做是为了更好地了解事情如何在幕后工作并具有更多的实现自由(例如,可变摘要).

通过首先编写完全连接层的示例,然后进行2D卷积,最后编写3D情况,我将顺利地进入3D卷积情况.在浏览代码时,如果你能检查一切是否正确完成会很好 - 代码运行,但我不能100%确定应用批量规范化的方式.我以更详细的问题结束这篇文章.

import tensorflow as tf

# This flag is used to allow/prevent batch normalization params updates
# depending on whether the model is being trained or used for prediction.
training = tf.placeholder_with_default(True, shape=())
Run Code Online (Sandbox Code Playgroud)

完全连接(FC)外壳

# Input.
INPUT_SIZE = 512
u = tf.placeholder(tf.float32, shape=(None, INPUT_SIZE))

# FC params: weights only, no bias as per [Ioffe & Szegedy 2015].
FC_OUTPUT_LAYER_SIZE = 1024
w = tf.Variable(tf.truncated_normal(
    [INPUT_SIZE, FC_OUTPUT_LAYER_SIZE], dtype=tf.float32, stddev=1e-1))

# Layer output …
Run Code Online (Sandbox Code Playgroud)

python machine-learning deep-learning tensorflow batch-normalization

17
推荐指数
1
解决办法
4718
查看次数

在TensorFlow中没有广播tf.matmul

我有一个问题,我一直在努力.它与tf.matmul()广播有关,也没有广播.

我在https://github.com/tensorflow/tensorflow/issues/216上发现了类似的问题,但tf.batch_matmul()对我的案例看起来并不像是一个解决方案.

我需要将输入数据编码为4D张量: X = tf.placeholder(tf.float32, shape=(None, None, None, 100)) 第一个维度是批次的大小,第二个维度是批次中的条目数量.您可以将每个条目想象为多个对象的组合(第三维).最后,每个对象由100个浮点值的向量描述.

请注意,我对第二维和第三维使用了None,因为实际大小可能会在每个批次中发生变化.但是,为简单起见,让我们用实际数字来形成张量: X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))

这些是我计算的步骤:

  1. 计算100个浮点值(例如,线性函数)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1)) Y = tf.matmul(X, W) 问题的每个向量的函数 :tf.matmul()使用tf.batch_matmul() Y的预期形状没有广播和没有成功:(5,10,4,50)

  2. 对批次的每个条目应用平均池(在每个条目的对象上): Y_avg = tf.reduce_mean(Y, 2) Y_avg的预期形状:(5,10,50)

我预计tf.matmul()会支持广播.然后我发现tf.batch_matmul(),但它看起来仍然不适用于我的情况(例如,W需要至少有3个维度,不清楚为什么).

顺便说一句,上面我使用了一个简单的线性函数(其权重存储在W中).但在我的模型中,我有一个深层网络.因此,我遇到的更普遍的问题是自动计算张量的每个切片的函数.这就是为什么我预期tf.matmul()会有广播行为(如果是这样,tf.batch_matmul()甚至可能根本不需要).

期待向您学习!阿莱西奥

broadcasting tensorflow

11
推荐指数
1
解决办法
6559
查看次数