car*_*lie 5 python scikit-learn
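This may be a silly question, but I can't find a way to match the target labels to the fetch_20newsgroups target names. Is it something as obvious as alt.atheism == 1, which is why I can't find it documented anywhere, or is there a matching method that I'm just not finding?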
>>> from sklearn.datasets import fetch_20newsgroups
>>> newsgroups_train = fetch_20newsgroups(subset='train')
>>> from pprint import pprint
>>> pprint(list(newsgroups_train.target_names))
['alt.atheism',
'comp.graphics',
'comp.os.ms-windows.misc',
'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware',
'comp.windows.x',
'misc.forsale',
'rec.autos',
'rec.motorcycles',
'rec.sport.baseball',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'soc.religion.christian',
'talk.politics.guns',
'talk.politics.mideast',
'talk.politics.misc',
'talk.religion.misc']
>>> newsgroups_train.target[:10]
array([12, 6, 9, 8, 6, 7, 9, 2, 13, 19])
Definitely not a silly question, since I couldn't find any documentation on it either.
I took a look at the source code of the fetch_20newsgroups function.
def fetch_20newsgroups(data_home=None, subset='train', categories=None,  # line#-149
                       shuffle=True, random_state=42,
                       remove=(),
                       download_if_missing=True, return_X_y=False):
    """Load the filenames and data from the 20 newsgroups dataset \
(classification).
    Download it if necessary.
    ...
    ...
    categories : None or collection of string or unicode  # line#-177
        If None (default), load all the categories.
        If not None, list of category names to load (other categories
        ignored).
    ...
    ...
    """
    ...
    ...
    if categories is not None:  # line#-287
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        mask = np.in1d(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]  # line#-294
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()
    ...
    ...
    return data
Note that one of the parameters is categories, and from the docstring:
    If None (default), load all the categories.
    If not None, list of category names to load
So by default, categories covers all of the target_names.
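Now, let's go to line 287 of the source code.
You can see that when categories is given, the categories are sorted by the index each one has in target_names.
Then, at line 294, target is filtered down to those indices.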
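The filtering and relabeling steps in that excerpt can be tried out on toy data without downloading anything. This is only a minimal sketch: filter_and_relabel is a hypothetical helper, not part of scikit-learn, the four-name list and the target array are made up for illustration, and np.isin is used here in place of the np.in1d call shown in the source.

```python
import numpy as np


def filter_and_relabel(target, target_names, categories):
    """Toy re-creation of the category filtering step (hypothetical helper)."""
    # Pair each requested category with its index in target_names, then sort
    # so the kept categories follow the original label order.
    labels = sorted((target_names.index(cat), cat) for cat in categories)
    label_ids, kept_names = zip(*labels)
    # Keep only the samples whose label is one of the requested categories.
    mask = np.isin(target, label_ids)
    filtered = np.asarray(target)[mask]
    # searchsorted maps the surviving labels onto 0..len(categories)-1.
    new_target = np.searchsorted(label_ids, filtered)
    return new_target, list(kept_names)


# Made-up data: four category names and five sample labels.
target_names = ['alt.atheism', 'comp.graphics', 'sci.med', 'sci.space']
target = np.array([3, 0, 1, 3, 2])

new_target, new_names = filter_and_relabel(
    target, target_names, ['sci.space', 'alt.atheism'])
print(new_target.tolist())  # [1, 0, 1]
print(new_names)            # ['alt.atheism', 'sci.space']
```

After filtering, each label is still a position in the (now shorter) target_names list, so the same index lookup keeps working when categories is passed.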
This tells us that the numbers you get in target are actually indices into target_names.
So you can match each of them to its name by its index in target_names.
for idx, cat in enumerate(newsgroups_train.target_names):
    print(idx, cat)
0 alt.atheism
1 comp.graphics
2 comp.os.ms-windows.misc
3 comp.sys.ibm.pc.hardware
4 comp.sys.mac.hardware
5 comp.windows.x
6 misc.forsale
7 rec.autos
8 rec.motorcycles
9 rec.sport.baseball
10 rec.sport.hockey
11 sci.crypt
12 sci.electronics
13 sci.med
14 sci.space
15 soc.religion.christian
16 talk.politics.guns
17 talk.politics.mideast
18 talk.politics.misc
19 talk.religion.misc
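To go the other way and turn a whole target array into names at once, NumPy integer indexing is one option. The sketch below copies the target_names list and the first ten labels from the question instead of calling fetch_20newsgroups, so it runs offline; names and label_of are illustrative variable names, not scikit-learn attributes.

```python
import numpy as np

# Copied from the question's output so no download is needed.
target_names = [
    'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
    'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
    'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
    'sci.space', 'soc.religion.christian', 'talk.politics.guns',
    'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc',
]
target = np.array([12, 6, 9, 8, 6, 7, 9, 2, 13, 19])

# Index the array of names with the integer labels: one name per sample.
names = np.array(target_names)[target]
for label, name in zip(target, names):
    print(label, name)

# Reverse lookup: category name -> integer label.
label_of = {name: idx for idx, name in enumerate(target_names)}
print(label_of['alt.atheism'])  # 0
```

With the real dataset, the same indexing should apply to newsgroups_train.target_names and newsgroups_train.target. The reverse lookup also answers the original question: alt.atheism is label 0, not 1.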