Ela*_*ine 1 python svm scikit-learn
我正在开发一个文本分类项目，并尝试使用 SVC(kernel='linear') 来获取特征重要性。这是我的代码（我修改了这篇文章中的代码）：
# Model inputs: the columns listed in `features` (defined elsewhere) and the
# 'label' column as the target.
X, y = df1[features], df1['label']
# Selector transformers: each pulls a single column out of the input so that a
# dedicated text or numeric sub-pipeline can process it on its own.
class TextSelector(BaseEstimator, TransformerMixin):
    """Select one column from the input for further transformation.

    Intended for the text column: ``X[key]`` (a Series when ``X`` is a pandas
    DataFrame) is handed on to the ``TfidfVectorizer`` step that follows it.
    """

    def __init__(self, key):
        # Stored under the constructor argument's own name so that
        # BaseEstimator.get_params() / clone() can find it.
        self.key = key

    def transform(self, X):
        """Return the column stored under ``self.key``."""
        column = X[self.key]
        return column

    def fit(self, X, y=None):
        """Do nothing: the selector learns nothing from the data."""
        return self
class NumberSelector(BaseEstimator, TransformerMixin):
    """Select one numeric column from the input, keeping it two-dimensional.

    Indexing with a one-element list (``X[[key]]``) rather than ``X[key]``
    yields a single-column table (a DataFrame when ``X`` is a pandas
    DataFrame), the 2-D shape that ``StandardScaler`` expects.
    """

    def __init__(self, key):
        # Same attribute name as the constructor argument (sklearn convention).
        self.key = key

    def transform(self, X):
        """Return ``X[[self.key]]``: the selected column as a one-column table."""
        wanted = [self.key]
        return X[wanted]

    def fit(self, X, y=None):
        """Do nothing and return ``self``; there is no state to learn."""
        return self
# Kept only so the module-level name still exists; the numeric pipelines below
# no longer share it (each gets its own scaler from _numeric_pipeline).
scaler = StandardScaler()


def _numeric_pipeline(key):
    """Build a select-column-then-standardize pipeline for numeric column *key*.

    Every call creates a fresh StandardScaler. Pipeline and FeatureUnion fit
    the estimator objects they are given in place (no copy when caching is
    off), so a single scaler shared by several pipelines would end up holding
    only the mean/variance of the last column fitted, and at predict time
    every numeric column would be scaled with those statistics.
    """
    return Pipeline([
        ('selector', NumberSelector(key=key)),
        ('standard', StandardScaler()),
    ])


text = Pipeline([
    ('selector', TextSelector(key='title_mainText')),
    ('vect', TfidfVectorizer(ngram_range=(1, 2))),
])
upper_title = _numeric_pipeline('upper_title')
upper_mainText = _numeric_pipeline('upper_mainText')
punct_title = _numeric_pipeline('punct_title')
punct_mainText = _numeric_pipeline('punct_mainText')
exclamations_title = _numeric_pipeline('exclamations_title')
exclamations_text = _numeric_pipeline('exclamations_text')

# FeatureUnion stacks its outputs in list order, so this order also fixes the
# column order of the classifier's coef_: every TF-IDF n-gram column first,
# then one column per numeric step.
feats = FeatureUnion([
    ('title_mainText', text),
    ('upper_title', upper_title),
    ('upper_mainText', upper_mainText),
    ('punct_title', punct_title),
    ('punct_mainText', punct_mainText),
    ('exclamations_text', exclamations_text),
    ('exclamations_title', exclamations_title),
])

# Wraps the very same `feats` object that `pipeline` uses below.
feature_processing = Pipeline([('feats', feats)])
pipeline = Pipeline([
    ('features', feats),
    ('classifier', SVC(C=1, kernel='linear', max_iter=1000, tol=0.0001, probability=True)),
])
def f_importances(coef, names):
    """Plot one weight per feature as a horizontal bar chart, sorted ascending.

    Parameters
    ----------
    coef : 1-D array-like, single-row 2-D array, or scipy sparse matrix
        One weight per entry of ``names``. A single-row matrix such as the
        ``coef_`` of a binary linear SVC is flattened to 1-D first.
    names : sequence of str
        Label for each weight, in the same order as ``coef``.

    Raises
    ------
    ValueError
        If ``coef`` is not (reducible to) 1-D, or if its length differs from
        ``len(names)``.
    """
    import numpy as np

    # An SVC trained on sparse (e.g. TF-IDF) input stores coef_ as a scipy
    # sparse matrix; matplotlib then truth-tests it and raises ValueError.
    if hasattr(coef, 'toarray'):
        coef = coef.toarray()
    imp = np.asarray(coef)
    if imp.ndim == 2 and imp.shape[0] == 1:
        imp = imp[0]
    if imp.ndim != 1:
        raise ValueError(
            ('coef must hold one weight per feature (1-D); got shape {}. '
             'Reduce multi-row coefficients first, e.g. coef.sum(axis=0).'
             ).format(imp.shape))
    names = list(names)
    # zip() would silently truncate to the shorter input and mislabel bars.
    if len(imp) != len(names):
        raise ValueError('got {} coefficients but {} names'.format(len(imp), len(names)))

    pairs = sorted(zip(imp.tolist(), names))
    sorted_imp = [weight for weight, _ in pairs]
    sorted_names = [label for _, label in pairs]
    positions = range(len(sorted_names))
    plt.barh(positions, sorted_imp, align='center')
    plt.yticks(positions, sorted_names)
    plt.show()
import numpy as np

pipeline.fit(X, y)
clf = pipeline.named_steps['classifier']

# One name per column produced by the FeatureUnion, in the union's own order.
# The TF-IDF step expands 'title_mainText' into one column per n-gram, so a
# hand-written list of seven names cannot line up with clf.coef_.
features_names = []
for step_name, transformer in pipeline.named_steps['features'].transformer_list:
    if step_name == 'title_mainText':
        vect = transformer.named_steps['vect']
        # scikit-learn >= 1.0 renamed get_feature_names to get_feature_names_out.
        get_terms = getattr(vect, 'get_feature_names_out', None) or vect.get_feature_names
        features_names.extend('title_mainText: %s' % term for term in get_terms())
    else:
        features_names.append(step_name)

coef = clf.coef_
# Sparse TF-IDF input makes coef_ a scipy sparse matrix; densify it.
if hasattr(coef, 'toarray'):
    coef = coef.toarray()
coef = np.asarray(coef)
# A binary SVC has a single coef_ row; for multiclass (one row per one-vs-one
# class pair) collapse the rows column-wise.
weights = coef[0] if coef.shape[0] == 1 else coef.sum(axis=0)

# A TF-IDF vocabulary yields far too many bars to read, so plot only the 20
# largest-magnitude weights.
top = np.argsort(np.abs(weights))[-20:]
f_importances(weights[top], [features_names[i] for i in top])
Run Code Online (Sandbox Code Playgroud)
但是,它显示错误消息,我不知道我哪里做错了。以前有人有过这方面的经验吗?
ValueError Traceback (most recent call last) in <module>() 13 pipeline.fit(X, y) 14 clf = pipeline.named_steps['classifier'] ---> 15 f_importances((clf.coef_[0]), features_names) 16
in f_importances(coef, names) 5 imp = coef 6 imp,names = zip(*sorted(zip(imp,names))) ----> 7 plt.barh(range(len(names)), imp, align='center') 8 plt.yticks(range(len(names)), names) 9 plt.show()
/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in barh(*args, **kwargs) 2667 mplDeprecation)
2668 try: -> 2669 ret = ax.barh(*args, **kwargs) 2670 2671 /anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in barh(self, *args, **kwargs) 2281
kwargs.setdefault('orientation', 'horizontal') 2282 patches = self.bar(x=left, height=height, width=width, -> 2283 bottom=y, **kwargs) 2284 return patches 2285 /anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs) 1715
warnings.warn(msg % (label_namer, func.__name__), 1716
RuntimeWarning, stacklevel=2) -> 1717 return func(ax, *args, **kwargs) 1718 pre_doc = inner.__doc__ 1719 if pre_doc is None: /anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in bar(self, *args, **kwargs) 2091 elif orientation == 'horizontal': 2092 r.sticky_edges.x.append(l) -> 2093 self.add_patch(r) 2094 patches.append(r) 2095
/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in add_patch(self, p) 1852 if p.get_clip_path() is None:
1853 p.set_clip_path(self.patch) -> 1854 self._update_patch_limits(p) 1855 1856 /anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _update_patch_limits(self, patch) 1868 # or height. 1869 1870 1871 1872
/anaconda3/lib/python3.6/site-packages/scipy/sparse/base.py in __bool__(self) 286 return self.nnz != 0 287 else: --> 288 raise ValueError("The truth value of an array with more than one " 289 "element is ambiguous. Use a.any() or a.all().") 290
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
谢谢你!
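Scikit-Learn 的文档指出，SVC 的 coef_ 属性是形状为 [n_class * (n_class-1) / 2, n_features] 的数组（一对一策略下每对类别占一行）。假设有 4 个类和 9 个特征，coef_ 的形状为 6 x 9（六行九列）。而 barh 期望每个特征只有一个值，而不是六个，因此会报错。如下例所示，把每一列的系数相加即可消除该错误。此外，用 TF-IDF 这类稀疏输入训练时，coef_ 是 scipy 稀疏矩阵（回溯最后正是在 scipy/sparse/base.py 中抛出的），需要先调用 .toarray() 转成普通数组。TF-IDF 还会把 title_mainText 展开成很多列，所以名称列表的长度必须与 coef_ 的列数一致。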
import numpy as np
import matplotlib.pyplot as plt
def f_importances(coef, names):
    """Plot one weight per feature as a horizontal bar chart, sorted ascending.

    Parameters
    ----------
    coef : 1-D array-like, single-row 2-D array, or scipy sparse matrix
        One weight per entry of ``names`` (e.g. ``clf.coef_.sum(axis=0)``).
    names : sequence of str
        Label for each weight, in the same order as ``coef``.

    Raises
    ------
    ValueError
        If ``coef`` is not (reducible to) 1-D, or if its length differs from
        ``len(names)``.
    """
    # Sparse coef_ (from sparse training input) must be densified before plotting.
    if hasattr(coef, 'toarray'):
        coef = coef.toarray()
    imp = np.asarray(coef)
    if imp.ndim == 2 and imp.shape[0] == 1:
        imp = imp[0]
    if imp.ndim != 1:
        raise ValueError(
            ('coef must hold one weight per feature (1-D); got shape {}. '
             'Reduce multi-row coefficients first, e.g. coef.sum(axis=0).'
             ).format(imp.shape))
    names = list(names)
    # zip() would silently truncate to the shorter input and mislabel bars.
    if len(imp) != len(names):
        raise ValueError('got {} coefficients but {} names'.format(len(imp), len(names)))

    pairs = sorted(zip(imp.tolist(), names))
    sorted_imp = [weight for weight, _ in pairs]
    sorted_names = [label for _, label in pairs]
    positions = range(len(sorted_names))
    plt.barh(positions, sorted_imp, align='center')
    plt.yticks(positions, sorted_names)
    plt.show()
features_names = ['title_mainText', 'upper_title', 'upper_mainText', 'punct_title', 'punct_mainText',
                  'exclamations_title', 'exclamations_text', 'title_words_not_stopword',
                  'text_words_not_stopword']
n_classes = 4
n_features = len(features_names)
# A one-vs-one SVC keeps one coefficient row per pair of classes: C(n, 2) rows.
# Integer arithmetic keeps the count exact (no float round-trip).
n_pairs = n_classes * (n_classes - 1) // 2
# Random stand-in for clf.coef_, shape (n_pairs, n_features).
clf_coef_ = np.random.randint(1, 30, size=(n_pairs, n_features))
# barh needs one value per feature, so collapse the rows column-wise.
# NOTE(review): with real signed SVC weights, summing can cancel opposite signs.
f_importances(clf_coef_.sum(axis=0), features_names)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
2125 次 |
最近记录: |