El *_*ude 5 numpy pytorch numpy-ndarray
将 PyTorch 张量转换为 NumPy 我得到
print(nn_result.shape)
# (2433, 2)
np_result = torch.argmax(nn_result).numpy()
type(np_result)
# <type 'numpy.ndarray'>
print(len(np_result))
TypeError: len() of unsized object
Run Code Online (Sandbox Code Playgroud)
为什么?我认为根据文档该numpy()函数会返回一个正确的ndarray,但它似乎不完整?
也许您想使用torch.argmax(nn_result, dim=1)? 由于dim默认为 0,因此它仅返回构造为张量的单个数字。让我用下面的例子来说明:
>>> x = np.array(1)
>>> x.shape
()
>>> len(x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: len() of unsized object
>>> x = np.array([1])
>>> x.shape
(1,)
>>> len(x)
1
Run Code Online (Sandbox Code Playgroud)
基本上np.array将采用object您构建的任何类型。在第一种情况下,对象不是数组,因此您看不到有效的形状。由于它不是数组,因此调用len会引发错误。
torch.argmaxwithdim=0返回一个张量,如上例第一种情况所示,因此会出现错误。
| 归档时间: |
|
| 查看次数: |
1335 次 |
| 最近记录: |