use*_*389 4 python numpy pytorch
给定 numpy(或 pytorch)中的二维张量,我可以同时沿所有维度进行部分切片,如下所示:
>>> import numpy as np
>>> a = np.arange(2*3).reshape(2,3)
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> a[1:,1:]
array([[ 5, 6, 7],
[ 9, 10, 11]])
Run Code Online (Sandbox Code Playgroud)
如果我在实现时不知道维数,如何实现相同的切片模式,无论张量中的维数如何?(即我想要a[1:]
如果a
只有一维、a[1:,1:]
二维、a[1:,1:,1:]
三维等等)
如果我可以用如下所示的一行代码来完成它,那就太好了,但这是无效的:
a[(1:,) * len(a.shape)] # SyntaxError: invalid syntax
Run Code Online (Sandbox Code Playgroud)
我对适用于 pytorch 张量的解决方案特别感兴趣(只需将上面的 numpy 替换为 torch,示例是相同的),但我认为如果该解决方案同时适用于 numpy 和 pytorch,那么它可能也是最好的。
答案:制作切片对象 的元组可以解决这个问题:
a[(slice(1,None),) * len(a.shape)]
Run Code Online (Sandbox Code Playgroud)
说明:
slice
是一个内置的 python 类(不依赖于 numpy 或 pytorch),它提供了用于描述切片的下标表示法的替代方法。 另一个问题的答案建议使用它作为在 python 变量中存储切片信息的方式。python 术语表指出
括号(下标)表示法在内部使用切片对象。
由于numpy ndarrays和pytorch 张量__getitem__
的方法支持切片的多维索引,因此它们也必须支持切片对象的多维索引,因此我们可以将这些切片创建一个具有正确长度的元组。
顺便说一句,您可以通过创建一个虚拟类来了解 python 如何使用切片对象,如下所示,然后对其进行切片:
class A(object):
def __getitem__(self, ix):
return ix
print(A()[5]) # 5
print(A()[1:]) # slice(1, None, None)
print(A()[1:,1:]) # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)]) # (slice(1, None, None), slice(1, None, None))
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
4557 次 |
最近记录: |