TensorFlow 中相当于 PyTorch 中 expand() 的函数是什么?

mau*_*una 5 python tensorflow pytorch

假设我有一个 2 x 3 矩阵,我想创建一个 6 x 2 x 3 矩阵,其中第一维中的每个元素都是原始的 2 x 3 矩阵。

在 PyTorch 中,我可以这样做:

import torch
from torch.autograd import Variable
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
x = Variable(torch.from_numpy(x))

# y is the desired result
y = x.unsqueeze(0).expand(6, 2, 3)
Run Code Online (Sandbox Code Playgroud)

在 TensorFlow 中执行此操作的等效方法是什么?我知道unsqueeze()相当于tf.expand_dims()但我不知道 TensorFlow 有任何相当于expand(). 我正在考虑使用tf.concat1 x 2 x 3 张量的列表,但不确定这是否是最好的方法。

小智 5

pytorch 的等效函数expand是 tensorflowtf.broadcast_to

文档:https : //www.tensorflow.org/api_docs/python/tf/broadcast_to


pat*_*_ai 0

Tensorflow 会自动广播,因此一般情况下您不需要执行任何操作。假设您有一个y'形状为 6x2x3 的形状,并且您的x形状为2x3,那么您已经可以y'*xy'+x将已经表现得好像您已经扩展了它一样。但如果出于其他原因你确实需要这样做,那么tensorflow中的命令是tile

y = tf.tile(tf.reshape(x, (1,2,3)), multiples=(6,1,1))
Run Code Online (Sandbox Code Playgroud)

文档: https: //www.tensorflow.org/api_docs/python/tf/tile