为什么在 MNIST 分类器代码中使用 X[0] 会给我一个错误?

Arj*_*wal 2 machine-learning mnist

我正在学习使用 MNIST 数据集进行分类。我遇到了一个错误,我无法弄清楚,我已经做了很多谷歌搜索,但我什么也做不了,也许你是专家,可以帮助我。这是代码——

>>> from sklearn.datasets import fetch_openml
>>> mnist = fetch_openml('mnist_784', version=1)
>>> mnist.keys()
Run Code Online (Sandbox Code Playgroud)

输出:dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

>>> X, y = mnist["data"], mnist["target"]
>>> X.shape
Run Code Online (Sandbox Code Playgroud)

输出:(70000, 784)

>>> y.shape
Run Code Online (Sandbox Code Playgroud)

输出:(70000)

>>> X[0]

output:KeyError                                  Traceback (most recent call last)
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2897             try:
-> 2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
<ipython-input-10-19c40ecbd036> in <module>
----> 1 X[0]

c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   2904             if self.columns.nlevels > 1:
   2905                 return self._getitem_multilevel(key)
-> 2906             indexer = self.columns.get_loc(key)
   2907             if is_integer(indexer):
   2908                 indexer = [indexer]

c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:
-> 2900                 raise KeyError(key) from err
   2901 
   2902         if tolerance is not None:

KeyError: 0
Run Code Online (Sandbox Code Playgroud)

请回答,可能有一个愚蠢的错误,因为我是 ML 的初学者。如果您也给我一些提示,那将非常有帮助。

Pro*_*oko 14

APIfetch_openml在版本之间发生了变化。在早期版本中,它返回一个numpy.ndarray数组。自0.24.0(2020 年 12 月)以来,as_frame参数fetch_openml被设置为auto(而不是False之前的默认选项),这pandas.DataFrame为您提供了 MNIST 数据。numpy.ndarray您可以通过设置强制将数据读取为 a as_frame = False。请参阅fetch_openml 参考


小智 10

如果您按照以下代码操作,则无需降级 scikit-learn 库:

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version= 1, as_frame= False)
mnist.keys()
Run Code Online (Sandbox Code Playgroud)


das*_*sum 7

我也面临同样的问题。

  • scikit 学习:0.24.0
  • matplotlib:3.3.3
  • 蟒蛇:3.9.1

我曾经用下面的代码来解决这个问题。

import matplotlib as mpl
import matplotlib.pyplot as plt


# instead of some_digit = X[0]
some_digit = X.to_numpy()[0]
some_digit_image = some_digit.reshape(28,28)

plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
Run Code Online (Sandbox Code Playgroud)