在numpy可以乘以一个三维阵列作为例子下面的2D阵列:
>>> X = np.random.randn(3,5,4) # [3,5,4]
... W = np.random.randn(5,5) # [5,5]
... out = np.matmul(W, X) # [3,5,4]
Run Code Online (Sandbox Code Playgroud)
从我的理解,np.matmul()需要W和沿的第一维播放它X。但在tensorflow它是不允许的:
>>> _X = tf.constant(X)
... _W = tf.constant(W)
... _out = tf.matmul(_W, _X)
ValueError: Shape must be rank 2 but is rank 3 for 'MatMul_1' (op: 'MatMul') with input shapes: [5,5], [3,5,4].
Run Code Online (Sandbox Code Playgroud)
那么,有没有什么等效np.matmul()上面呢tensorflow?tensorflow将 2d 张量与 3d 张量相乘的最佳实践是什么?
尝试使用tf.tile在乘法之前匹配矩阵的维度。numpy 的自动广播功能似乎没有在 tensorflow 中实现。你必须手动完成。
W_T = tf.tile(tf.expand_dims(W,0),[3,1,1])
这应该可以解决问题
import numpy as np
import tensorflow as tf
X = np.random.randn(3,4,5)
W = np.random.randn(5,5)
_X = tf.constant(X)
_W = tf.constant(W)
_W_t = tf.tile(tf.expand_dims(_W,0),[3,1,1])
with tf.Session() as sess:
print(sess.run(tf.matmul(_X,_W_t)))
Run Code Online (Sandbox Code Playgroud)
您可以tensordot改用:
tf.transpose(tf.tensordot(_W, _X, axes=[[1],[1]]),[1,0,2])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4657 次 |
| 最近记录: |