张量流或pytorch中的分区矩阵乘法

Ami*_*mir 6 matrix multiplication matrix-multiplication tensorflow pytorch

假设我有矩阵P,其大小[4, 4]划分为4个较小的矩阵[2,2]。如何有效地将此块矩阵乘以另一个矩阵(不是分区矩阵而是较小的矩阵)?

假设我们的原始矩阵为:

P = [ 1 1 2 2
      1 1 2 2
      3 3 4 4
      3 3 4 4]
Run Code Online (Sandbox Code Playgroud)

其中分为子矩阵:

P_1 = [1 1    , P_2 = [2 2  , P_3 = [3 3   P_4 = [4 4
       1 1]            2 2]          3 3]         4 4]
Run Code Online (Sandbox Code Playgroud)

现在我们的P是:

P = [P_1 P_2
     P_3 p_4]
Run Code Online (Sandbox Code Playgroud)

下一步,我想在P和较小矩阵之间进行逐元素乘法,其大小等于子矩阵的数量:

P * [ 1 0   =   [P_1  0  = [1 1 0 0 
      0 0 ]      0    0]    1 1 0 0
                            0 0 0 0
                            0 0 0 0]    
Run Code Online (Sandbox Code Playgroud)

GZ0*_*GZ0 2

以下是基于 Tensorflow 的通用解决方案,适用于任意形状的输入矩阵p(大)和m(小),只要 的大小可被两个轴上p的大小整除。m

def block_mul(p, m):
   p_x, p_y = p.shape
   m_x, m_y = m.shape
   m_4d = tf.reshape(m, (m_x, 1, m_y, 1))
   m_broadcasted = tf.broadcast_to(m_4d, (m_x, p_x // m_x, m_y, p_y // m_y))
   mp = tf.reshape(m_broadcasted, (p_x, p_y))
   return p * mp
Run Code Online (Sandbox Code Playgroud)

测试:

import tensorflow as tf

tf.enable_eager_execution()

p = tf.reshape(tf.constant(range(36)), (6, 6))
m = tf.reshape(tf.constant(range(9)), (3, 3))
print(f"p:\n{p}\n")
print(f"m:\n{m}\n")
print(f"block_mul(p, m):\n{block_mul(p, m)}")
Run Code Online (Sandbox Code Playgroud)

输出(Python 3.7.3、Tensorflow 1.13.1):

p:
[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

m:
[[0 1 2]
 [3 4 5]
 [6 7 8]]

block_mul(p, m):
[[  0   0   2   3   8  10]
 [  0   0   8   9  20  22]
 [ 36  39  56  60  80  85]
 [ 54  57  80  84 110 115]
 [144 150 182 189 224 232]
 [180 186 224 231 272 280]]
Run Code Online (Sandbox Code Playgroud)

使用隐式广播的另一种解决方案如下:

def block_mul2(p, m):
   p_x, p_y = p.shape
   m_x, m_y = m.shape
   p_4d = tf.reshape(p, (m_x, p_x // m_x, m_y, p_y // m_y))
   m_4d = tf.reshape(m, (m_x, 1, m_y, 1))
   return tf.reshape(p_4d * m_4d, (p_x, p_y))
Run Code Online (Sandbox Code Playgroud)