将 PyTorch 张量与 scikit-learn 结合使用

SMM*_*iSP 5 python numpy scikit-learn pytorch tensor

在使用 scikit-learn 时,我可以使用 PyTorch 张量代替 NumPy 数组吗?

我尝试了 scikit-learn 中的一些方法,例如train_test_splitStandardScalar,它似乎工作得很好,但是当我使用 PyTorch 张量而不是 NumPy 数组时,有什么我应该知道的吗?

根据https://scikit-learn.org/stable/faq.html#how-can-i-load-my-own-datasets-into-a-format-usable-by-scikit-learn上的这个问题:

numpy 数组或 scipy 稀疏矩阵。其他可转换为数值数组的类型(例如 pandas DataFrame)也是可接受的。

这是否意味着使用 PyTorch 张量是完全安全的?

Aya*_*Das 7

我不认为 scikit-learn 直接支持 PyTorch 张量。但您始终可以从 PyTorch 张量获取底层 numpy 数组

my_nparray = my_tensor.numpy()
Run Code Online (Sandbox Code Playgroud)

然后将其与 scikit learn 函数一起使用。