Scikit-learn 如何检查模型(例如 TfidfVectorizer)是否已经拟合

Cen*_*tAu 5 python numpy machine-learning scikit-learn

对于从文本中提取特征,如何检查向量化器(例如 TfIdfVectorizer 或 CountVectorizer)是否已经适合训练数据?
特别是,我希望代码能够自动确定矢量化器是否已经适合。

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

def vectorize_data(texts):
  # if vectorizer has not been already fit
  vectorizer.fit_transform(texts)
  # else
  vectorizer.transform(texts)
Run Code Online (Sandbox Code Playgroud)

Viv*_*mar 5

您可以使用check_is_fitted基本上用于执行此操作的 。

源代码中,TfidfVectorizer.transform()您可以检查其用法:

def transform(self, raw_documents, copy=True):

    # This is what you need.
    check_is_fitted(self, '_tfidf', 'The tfidf vector is not fitted')

    X = super(TfidfVectorizer, self).transform(raw_documents)
    return self._tfidf.transform(X, copy=False)
Run Code Online (Sandbox Code Playgroud)

所以在你的情况下,你可以这样做:

from sklearn.utils.validation import check_is_fitted

def vectorize_data(texts):

    try:
        check_is_fitted(vectorizer, '_tfidf', 'The tfidf vector is not fitted')
    except NotFittedError:
        vectorizer.fit(texts)

    # In all cases vectorizer if fit here, so just call transform()
    vectorizer.transform(texts)
Run Code Online (Sandbox Code Playgroud)