pat*_*_ai 12 python pytorch tensor
在numpy中,V.shape给出一个V维数的元组.
在tensorflow中V.get_shape().as_list()给出了V维的整数列表.
在pytorch中,V.size()给出一个大小的对象,但是如何将它转换为整数?
alv*_*vas 22
简单地说list(var.size()),例如:
>>> import torch
>>> var = torch.tensor([[1,0], [0,1]])
# Using .size function, returns a torch.Size object.
>>> var.size()
torch.Size([2, 2])
>>> type(var.size())
<class 'torch.Size'>
# Similarly, using .shape
>>> var.shape
torch.Size([2, 2])
>>> type(var.shape)
<class 'torch.Size'>
Run Code Online (Sandbox Code Playgroud)
如果您喜欢NumPyish语法,那么这里有tensor.shape。
In [3]: ar = torch.rand(3, 3)
In [4]: ar.shape
Out[4]: torch.Size([3, 3])
# method-1
In [7]: list(ar.shape)
Out[7]: [3, 3]
# method-2
In [8]: [*ar.shape]
Out[8]: [3, 3]
# method-3
In [9]: [*ar.size()]
Out[9]: [3, 3]
Run Code Online (Sandbox Code Playgroud)
PS:请注意,这tensor.shape是的别名tensor.size(),尽管tensor.shape是所讨论的张量的属性,而是tensor.size()函数。