PyTorch:如何将Tensor的形状作为int列表

pat*_*_ai 12 python pytorch tensor

在numpy中,V.shape给出一个V维数的元组.

在tensorflow中V.get_shape().as_list()给出了V维的整数列表.

在pytorch中,V.size()给出一个大小的对象,但是如何将它转换为整数?

alv*_*vas 22

简单地说list(var.size()),例如:

>>> import torch
>>> var = torch.tensor([[1,0], [0,1]])

# Using .size function, returns a torch.Size object.
>>> var.size()
torch.Size([2, 2])
>>> type(var.size())
<class 'torch.Size'>

# Similarly, using .shape
>>> var.shape
torch.Size([2, 2])
>>> type(var.shape)
<class 'torch.Size'>
Run Code Online (Sandbox Code Playgroud)

  • 在我的例子中,“list(var.size())”的元素具有“tensor”类型,而不是“int”类型。怎么处理呢?我需要“int” (2认同)

kma*_*o23 5

如果您喜欢NumPyish语法,那么这里有tensor.shape

In [3]: ar = torch.rand(3, 3)

In [4]: ar.shape
Out[4]: torch.Size([3, 3])

# method-1
In [7]: list(ar.shape)
Out[7]: [3, 3]

# method-2
In [8]: [*ar.shape]
Out[8]: [3, 3]

# method-3
In [9]: [*ar.size()]
Out[9]: [3, 3]
Run Code Online (Sandbox Code Playgroud)

PS:请注意,这tensor.shape是的别名tensor.size(),尽管tensor.shape是所讨论的张量的属性,而是tensor.size()函数。

  • 代码的哪一部分仅在GPU机器上有效?`tensor.shape`? (9认同)