矩阵乘法与 Python 中的对象数组

Ove*_*olt 5 python arrays numpy multiplication

我想知道如何在 numpy 中使用dtype=object. 我已经同态加密的被封装在一个类的数字Ciphertext为我所重写基本的数学运算符,例如__add____mul__等等。

我创建了 numpy 数组,其中每个条目都是我的类的一个实例,Ciphertext并且 numpy 了解如何广播加法和乘法运算就好了。

    encryptedInput = builder.encrypt_as_array(np.array([6,7])) # type(encryptedInput) is <class 'numpy.ndarray'>
    encryptedOutput = encryptedInput + encryptedInput
    builder.decrypt(encryptedOutput)                           # Result: np.array([12,14])
Run Code Online (Sandbox Code Playgroud)

但是,numpy 不会让我做矩阵乘法

out = encryptedInput @ encryptedInput # TypeError: Object arrays are not currently supported
Run Code Online (Sandbox Code Playgroud)

考虑到加法和乘法有效,我不太明白为什么会发生这种情况。我想这与 numpy 无法知道对象的形状有关,因为它可能是一个列表或一些花哨的东西。

天真的解决方案:我可以编写自己的类来扩展ndarray和覆盖__matmul__操作,但我可能会失去性能,而且这种方法需要实现广播等,所以我基本上会重新发明轮子,以获得正确的东西现在。

问题:如何在数组上使用 numpy 提供的标准矩阵乘法,dtype=objects其中对象的行为与数字完全相同?

先感谢您!

Ove*_*olt 3

无论出于何种原因,matmul 不起作用,但 tensordot 函数按预期工作。

encryptedInput = builder.encrypt_as_array(np.array([6,7]))
out = np.tensordot(encryptedInput, encryptedInput, axes=([1,0])) 
    # Correct Result: [[ 92. 105.]
    #                  [120. 137.]]
Run Code Online (Sandbox Code Playgroud)

现在只是调整轴很麻烦。我仍然想知道这是否真的比使用 for 循环的简单实现更快。