SMM*_*iSP 5 python numpy scikit-learn pytorch tensor
在使用 scikit-learn 时,我可以使用 PyTorch 张量代替 NumPy 数组吗?
我尝试了 scikit-learn 中的一些方法,例如train_test_split和StandardScalar,它似乎工作得很好,但是当我使用 PyTorch 张量而不是 NumPy 数组时,有什么我应该知道的吗?
numpy 数组或 scipy 稀疏矩阵。其他可转换为数值数组的类型(例如 pandas DataFrame)也是可接受的。
这是否意味着使用 PyTorch 张量是完全安全的?
我不认为 scikit-learn 直接支持 PyTorch 张量。但您始终可以从 PyTorch 张量获取底层 numpy 数组
my_nparray = my_tensor.numpy()
Run Code Online (Sandbox Code Playgroud)
然后将其与 scikit learn 函数一起使用。