Matching targets to target names in fetch_20newsgroups

car*_*lie 5 python scikit-learn

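This may be a silly question, but I can't find a way to match the target labels to the fetch_20newsgroups target names. Is it as obvious as alt.atheism == 1, which is why I can't find it documented anywhere, or is there a matching method I just can't find?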

>>> from sklearn.datasets import fetch_20newsgroups
>>> newsgroups_train = fetch_20newsgroups(subset='train')

>>> from pprint import pprint
>>> pprint(list(newsgroups_train.target_names))
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']
>>> newsgroups_train.target[:10]
array([12,  6,  9,  8,  6,  7,  9,  2, 13, 19])

ywb*_*aek 4

Definitely not a silly question; I couldn't find any documentation on this either.

So I took a look at the source code of the fetch_20newsgroups function.

def fetch_20newsgroups(data_home=None, subset='train', categories=None,  # line#-149
                       shuffle=True, random_state=42,
                       remove=(),
                       download_if_missing=True, return_X_y=False):
    """Load the filenames and data from the 20 newsgroups dataset \
(classification).
    Download it if necessary.
...
...
    categories : None or collection of string or unicode                 # line#-177
        If None (default), load all the categories.
        If not None, list of category names to load (other categories
        ignored).
...
...
    """
...
...
    if categories is not None:                                           # line#-287
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        mask = np.in1d(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]                                  # line#-294
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()
...
...
    return data

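Notice that one of the parameters is categories, and the docstring says:

If None (default), load all the categories.
If not None, list of category names to load (other categories ignored).

So by default, categories covers all of the target_names.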

Now let's go to line 287 of the source. You can see that when categories is given, the selected categories are sorted by the index of each category in target_names.

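Then at line 294, target is filtered down to those indices, the next line uses np.searchsorted to renumber them as contiguous labels 0, 1, 2, ..., and data.target_names is reset to the selected categories in that same sorted order.
This tells us that the numbers you get in target are in fact indices into target_names.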
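To see the renumbering in isolation, here is a small sketch that reruns the same steps as the excerpt on made-up toy data (the four-name list, the five labels and the requested subset are hypothetical, not taken from the real dataset). It uses np.isin in place of np.in1d; the two behave the same on 1-D arrays, and newer NumPy versions recommend isin.

```python
import numpy as np

# Hypothetical toy stand-in for the full dataset: five documents whose
# labels are indices into the full list of names.
full_target_names = ["alt.atheism", "comp.graphics", "sci.med", "sci.space"]
full_target = np.array([3, 0, 2, 1, 3])

# The caller asks for a subset, in an arbitrary order.
categories = ["sci.space", "alt.atheism"]

# Same steps as the quoted excerpt:
labels = [(full_target_names.index(cat), cat) for cat in categories]
labels.sort()                             # order by index in the full list
labels, categories = zip(*labels)         # (0, 3), ('alt.atheism', 'sci.space')
mask = np.isin(full_target, labels)       # keep only documents in the subset
target = full_target[mask]                # old labels: [3, 0, 3]
target = np.searchsorted(labels, target)  # contiguous labels: [1, 0, 1]
target_names = list(categories)

# Each new label is still a valid index into the new target_names.
for t in target:
    print(t, target_names[t])
```

The request order of categories does not matter: the resulting target_names always follows the order of the full list, and every value in target indexes into it.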

So you can match each of them by its index in target_names:

for idx, cat in enumerate(newsgroups_train.target_names):
    print(idx, cat)
0 alt.atheism
1 comp.graphics
2 comp.os.ms-windows.misc
3 comp.sys.ibm.pc.hardware
4 comp.sys.mac.hardware
5 comp.windows.x
6 misc.forsale
7 rec.autos
8 rec.motorcycles
9 rec.sport.baseball
10 rec.sport.hockey
11 sci.crypt
12 sci.electronics
13 sci.med
14 sci.space
15 soc.religion.christian
16 talk.politics.guns
17 talk.politics.mideast
18 talk.politics.misc
19 talk.religion.misc
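Applied to the arrays from the question, a minimal sketch: the names and the first ten labels below are copied from the question's output so the snippet runs without downloading anything. With the real data you would use newsgroups_train.target_names and newsgroups_train.target instead. It also builds a reverse name-to-label dictionary (label_of is just an illustrative name), which shows that alt.atheism is label 0, not 1.

```python
# target_names exactly as printed in the question.
target_names = [
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball",
    "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc",
]

# First ten entries of newsgroups_train.target from the question.
target = [12, 6, 9, 8, 6, 7, 9, 2, 13, 19]

# Label -> name: each label is a position in target_names.
decoded = [target_names[i] for i in target]
print(decoded)

# Name -> label: the reverse lookup.
label_of = {name: idx for idx, name in enumerate(target_names)}
print(label_of["alt.atheism"])  # 0
```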