遵循Tensorflow的性能最佳实践,我使用的是NCHW数据格式,但我不确定要在tensorflow.nn.conv2d中使用的过滤器形状.
该文件称[filter_height, filter_width, in_channels, out_channels]
用于NHWC格式,但不清楚如何处理NCHW.
应该使用相同的形状?
小智 0
使用相同的过滤器形状应该可以。函数参数的唯一变化是步幅。举个例子,假设您希望您的架构能够使用两种格式,这也是推荐的:
# input -> Tensor in NCHW format
if use_nchw:
result = tf.nn.conv2d(
input=input,
filter=filter,
strides=[1, 1, stride, stride],
data_format='NCHW')
else:
input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC
result = tf.nn.conv2d(
input=input_t,
filter=filter,
strides=[1, stride, stride, 1])
result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW
Run Code Online (Sandbox Code Playgroud)