Tensorflow - 带批量数据的输入矩阵的matmul

yok*_*oki 37 python tensorflow

我有一些数据表示input_x.它是一个未知大小的张量(应该是批量输入),每个项目的大小n.input_x经历tf.nn.embedding_lookup,让embed现在的尺寸为[?, n, m]这里m是嵌入尺寸并?指未知的批量大小.

这在这里描述:

input_x = tf.placeholder(tf.int32, [None, n], name="input_x") 
embed = tf.nn.embedding_lookup(W, input_x)
Run Code Online (Sandbox Code Playgroud)

我现在试图将输入数据中的每个样本(现在通过嵌入维度扩展)乘以矩阵变量,U我似乎无法得到如何做到这一点.

我首先尝试使用,tf.matmul但由于形状不匹配而导致错误.然后我通过扩展U和应用的维度尝试了以下内容batch_matmul(我也试过了函数tf.nn.math_ops.,结果是一样的):

U = tf.Variable( ... )    
U1 = tf.expand_dims(U,0)
h=tf.batch_matmul(embed, U1)
Run Code Online (Sandbox Code Playgroud)

这会传递初始编译,但是当应用实际数据时,我收到以下错误:

In[0].dim(0) and In[1].dim(0) must be the same: [64,58,128] vs [1,128,128]

我也知道为什么会发生这种情况 - 我复制了U现在的维度1,但是小批量大小64不合适.

如何正确地对张量矩阵输入进行矩阵乘法(对于未知的批量大小)?

Sal*_*ali 76

以前的答案已经过时了.目前tf.matmul()支持等级> 2的张量:

输入必须是矩阵(或秩> 2的张量,表示矩阵批量),具有匹配的内部维度,可能在换位之后.

tf.batch_matmul()被删除了,tf.matmul()是进行批量乘法的正确方法.从以下代码可以理解主要思想:

import tensorflow as tf
batch_size, n, m, k = 10, 3, 5, 2
A = tf.Variable(tf.random_normal(shape=(batch_size, n, m)))
B = tf.Variable(tf.random_normal(shape=(batch_size, m, k)))
tf.matmul(A, B)
Run Code Online (Sandbox Code Playgroud)

现在你将收到一个形状的张量(batch_size, n, k).这是这里发生的事情.假设你有batch_size矩阵nxmbatch_size矩阵mxk.现在,对于每对它们,你计算出nxm X mxk哪个给你一个nxk矩阵.你会拥有batch_size它们.

请注意,这样的事情也是有效的:

A = tf.Variable(tf.random_normal(shape=(a, b, n, m)))
B = tf.Variable(tf.random_normal(shape=(a, b, m, k)))
tf.matmul(A, B)
Run Code Online (Sandbox Code Playgroud)

并会给你一个形状 (a, b, n, k)

  • 如果在问题中你想要将一个矩阵与其他矩阵相乘,那么正确的方法是什么?你必须复制(平铺)单个矩阵batch_sizetimes还是有更好的方法? (6认同)
  • 这似乎没有回答原始问题 (3认同)
  • @KarlSt根据我的实验,当第一个N-2维度不匹配时,这不起作用.显然,这个命令的numpy版本支持广播,但我认为在TF中执行它的唯一方法是将单个矩阵batch_size时间平铺.我甚至试过播放转置技巧(所以看起来矩阵是[batch_size,n,m],第二个矩阵是[1,m,k]),没有运气.我不确定它可以被称为bug,但显然,这应该在TF中实现,因为它是如此常见的操作. (2认同)

P-G*_*-Gn 24

我想将一批矩阵与一批相同长度的矩阵成对地相乘

M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((batch_size, m, p))

# python >= 3.5
MN = M @ N
# or the old way,
MN = tf.matmul(M, N)
# MN has shape (batch_size, n, p)
Run Code Online (Sandbox Code Playgroud)

我想将一批矩阵与一批相同长度的矢量成对地相乘

我们通过添加和删除维度来回到案例1 v.

M = tf.random_normal((batch_size, n, m))
v = tf.random_normal((batch_size, m))

Mv = (M @ v[..., None])[..., 0]
# Mv has shape (batch_size, n)
Run Code Online (Sandbox Code Playgroud)

我想将一个矩阵与一批矩阵相乘

在这种情况下,我们不能简单地1向单个矩阵添加批量维度,因为tf.matmul不在批量维度中广播.

3.1.单个矩阵位于右侧

在这种情况下,我们可以使用简单的重塑将矩阵批处理视为单个大矩阵.

M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((m, p))

MN = tf.reshape(tf.reshape(M, [-1, m]) @ N, [-1, n, p])
# MN has shape (batch_size, n, p)
Run Code Online (Sandbox Code Playgroud)

3.2.单个矩阵位于左侧

这种情况比较复杂.我们可以通过转置矩阵来回到案例3.1.

MT = tf.matrix_transpose(M)
NT = tf.matrix_transpose(N)
NTMT = tf.reshape(tf.reshape(NT, [-1, m]) @ MT, [-1, p, n])
MN = tf.matrix_transpose(NTMT)
Run Code Online (Sandbox Code Playgroud)

然而,换位可能是一项昂贵的操作,并且在这里它在整批矩阵上完成两次.简单地复制M以匹配批量维度可能更好:

MN = tf.tile(M[None], [batch_size, 1, 1]) @ N
Run Code Online (Sandbox Code Playgroud)

分析将告诉哪个选项对于给定的问题/硬件组合更有效.

我想将一个矩阵与一批向量相乘

这看起来类似于情况3.2,因为单个矩阵在左侧,但它实际上更简单,因为转置向量本质上是无操作.我们最终得到了

M = tf.random_normal((n, m))
v = tf.random_normal((batch_size, m))

MT = tf.matrix_transpose(M)
Mv = v @ MT
Run Code Online (Sandbox Code Playgroud)

怎么样einsum

所有以前的乘法都可以用tf.einsum瑞士军刀写成.例如,3.2的第一个解决方案可以简单地写成

MN = tf.einsum('nm,bmp->bnp', M, N)
Run Code Online (Sandbox Code Playgroud)

但是请注意,einsum最终依靠tranposematmul参与计算.

因此,尽管einsum编写矩阵乘法是一种非常方便的方法,但它隐藏了下面的操作的复杂性 - 例如,猜测einsum表达式将转置数据的次数并不是直截了当的,因此操作的成本会很高.此外,它可能隐藏了这样一个事实,即同一操作可能有多种替代方案(见案例3.2),并且可能不一定选择更好的选择.

出于这个原因,我个人会使用上面那些明确的公式来更好地传达它们各自的复杂性.虽然如果你知道自己在做什么,并且喜欢einsum语法的简单性,那么一定要去实现它.


Sty*_*rke 18

MATMUL操作仅适用于矩阵(二维张量).以下是执行此操作的两种主要方法,均假设U为2D张量.

  1. 切片embed成二维张量和繁殖他们每个人U独立.这可能是最容易使用tf.scan()这样的:

    h = tf.scan(lambda a, x: tf.matmul(x, U), embed)
    
    Run Code Online (Sandbox Code Playgroud)
  2. 另一方面,如果效率很重要,那么重塑embed为2D张量可能会更好,因此可以使用以下单个方法完成乘法运算matmul:

    embed = tf.reshape(embed, [-1, m])
    h = tf.matmul(embed, U)
    h = tf.reshape(h, [-1, n, c])
    
    Run Code Online (Sandbox Code Playgroud)

    在哪里c是列数U.最后一次重塑将确保这h是一个3D张量,其中第0维度对应于批次,就像原始x_inputembed.

  • 这个答案是[已废弃](/sf/answers/3068081201/). (6认同)