Is prediction with scikit-learn models thread-safe?

Tob*_*ann 9 python thread-safety scikit-learn

Given some classifier (SVC/Forest/NN/whatever), is it safe to call .predict on the same instance concurrently from different threads?

From a distance, my guess would be that they don't change any internal state, but I couldn't find anything about this in the documentation.

Here is a minimal example showing what I mean:

#!/usr/bin/env python3
import threading

from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

X, y = datasets.load_iris(return_X_y=True)

# Some model. Might be any type, e.g. (keep exactly one line active):
# clf = svm.SVC()
# clf = RandomForestClassifier()
clf = MLPClassifier(solver='lbfgs')

clf.fit(X, y)


def use_model_for_predictions():
    # Every thread calls predict on the same shared, already-fitted instance
    for _ in range(10000):
        clf.predict(X[0:1])


# Is this safe?
thread_1 = threading.Thread(target=use_model_for_predictions)
thread_2 = threading.Thread(target=use_model_for_predictions)
thread_1.start()
thread_2.start()
# Wait for both threads to finish
thread_1.join()
thread_2.join()
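The criterion in the question ("they don't change any internal state") is the one that matters: what would make concurrent predict calls unsafe is a predict that writes to the shared instance. A minimal sketch of that failure, assuming a made-up CachingPredictor class (not scikit-learn code) that stores its input on self; a threading.Barrier forces the unlucky interleaving so the demo is deterministic rather than depending on timing.

```python
import threading


class CachingPredictor:
    """Hypothetical model whose predict() writes to self, so it is NOT thread-safe."""

    def __init__(self, barrier=None):
        # Only used in this demo to force an unlucky thread interleaving
        self._barrier = barrier

    def predict(self, x):
        self._last_input = x          # shared state stored on the instance
        if self._barrier is not None:
            self._barrier.wait()      # let the other thread overwrite _last_input
        return self._last_input * 2   # may now read the other thread's input


model = CachingPredictor(threading.Barrier(2))
results = {}


def worker(x):
    results[x] = model.predict(x)


threads = [threading.Thread(target=worker, args=(x,)) for x in (1, 100)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Correct output would be {1: 2, 100: 200}, but both calls read the same input
print(results[1] == results[100])  # True
```

Keeping intermediate values in local variables instead of on self, or serializing the call with a threading.Lock, removes this race.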

Rul*_*uli 1

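Check out this Q&A: the predict and predict_proba methods should be thread-safe, since they only call NumPy on the fitted attributes and never modify the model itself. So the answer to your question is yes.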
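The argument rests on a simple pattern: fit writes attributes, predict only reads them and builds a fresh result. A minimal pure-Python sketch of that pattern, assuming a hypothetical ToyNearestCentroid class (not a scikit-learn estimator), calls predict on one shared instance from a concurrent.futures.ThreadPoolExecutor and checks that every result matches the sequential one and that the instance's attributes are unchanged afterwards.

```python
import copy
from concurrent.futures import ThreadPoolExecutor


class ToyNearestCentroid:
    """Hypothetical toy classifier: fit() writes state, predict() only reads it."""

    def fit(self, X, y):
        # Fitted attributes, named with a trailing underscore like scikit-learn's
        self.classes_ = sorted(set(y))
        self.centroids_ = {}
        for c in self.classes_:
            rows = [x for x, label in zip(X, y) if label == c]
            self.centroids_[c] = [sum(col) / len(rows) for col in zip(*rows)]
        return self

    def predict(self, X):
        # Read-only: uses self.classes_ and self.centroids_, returns a new list
        def sq_dist(a, b):
            return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

        return [min(self.classes_, key=lambda c: sq_dist(x, self.centroids_[c]))
                for x in X]


X = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]]
y = [0, 0, 1, 1]
clf = ToyNearestCentroid().fit(X, y)

state_before = copy.deepcopy(vars(clf))
expected = clf.predict(X)

# Call predict on the same instance from several threads at once
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(lambda _: clf.predict(X), range(1000)))

print(all(r == expected for r in results))  # True
print(vars(clf) == state_before)            # True
```

A passing run like this is consistent with thread safety but does not prove it; the real guarantee comes from predict not writing to self, which is what the source excerpts below are used to argue.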

You can also find some information in the replies here.

For example, in Naive Bayes the predict code looks like this:

def predict(self, X):
    """
    Perform classification on an array of test vectors X.
    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
    Returns
    -------
    C : ndarray of shape (n_samples,)
        Predicted target values for X
    """
    check_is_fitted(self)
    X = self._check_X(X)
    jll = self._joint_log_likelihood(X)
    return self.classes_[np.argmax(jll, axis=1)]

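You can see that the first two lines are just checks: check_is_fitted(self) verifies that the model has been fitted, and self._check_X(X) validates the input. The abstract method _joint_log_likelihood is the one we are interested in; it is described as: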

@abstractmethod
def _joint_log_likelihood(self, X):
    """Compute the unnormalized posterior log probability of X
    I.e. ``log P(c) + log P(x|c)`` for all rows x of X, as an array-like of
    shape (n_classes, n_samples).
    Input is passed to _joint_log_likelihood as-is by predict,
    predict_proba and predict_log_proba.
    """

Finally, for multinomial NB (MultinomialNB), for example, the function looks like this (source):

def _joint_log_likelihood(self, X):
    """
    Compute the unnormalized posterior log probability of X, which is
    the features' joint log probability (feature log probability times
    the number of times that word appeared in that document) times the
    class prior (since we're working in log space, it becomes an addition)
    """
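    # safe_sparse_dot (sklearn.utils.extmath) is a matrix product that works for
    # dense and sparse X; a plain `*` would be elementwise for NumPy arrays
    joint_prob = safe_sparse_dot(X, self.feature_log_prob_.T) + self.class_log_prior_
    return joint_prob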
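To make the arithmetic concrete, here is a pure-Python sketch of the same computation: jll[i][c] = sum_j X[i][j] * feature_log_prob_[c][j] + class_log_prior_[c], followed by classes_[argmax(jll, axis=1)] as in the predict excerpt. The function names and the "fitted" parameter values are made up for illustration (they are not produced by fit and are not scikit-learn APIs); the point is that both steps only read their inputs and return new lists.

```python
import math


def joint_log_likelihood(X, feature_log_prob, class_log_prior):
    """Pure-Python mirror of X @ feature_log_prob.T + class_log_prior.

    Reads its arguments only and returns a new nested list of shape
    (n_samples, n_classes).
    """
    return [
        [sum(xj * fj for xj, fj in zip(row, feat)) + prior
         for feat, prior in zip(feature_log_prob, class_log_prior)]
        for row in X
    ]


def predict(X, classes, feature_log_prob, class_log_prior):
    # classes[argmax(jll, axis=1)]: pick the highest-scoring class per row
    jll = joint_log_likelihood(X, feature_log_prob, class_log_prior)
    return [classes[max(range(len(r)), key=r.__getitem__)] for r in jll]


# Made-up "fitted" parameters: 2 classes, 3 features (word counts)
classes = ["spam", "ham"]
feature_log_prob = [[math.log(0.7), math.log(0.2), math.log(0.1)],
                    [math.log(0.1), math.log(0.3), math.log(0.6)]]
class_log_prior = [math.log(0.5), math.log(0.5)]

X = [[3, 0, 0], [0, 1, 4]]
print(predict(X, classes, feature_log_prob, class_log_prior))  # ['spam', 'ham']
```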

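As you can see, there is nothing thread-unsafe in predict: it only reads fitted attributes (classes_, feature_log_prob_, class_log_prior_) and returns new arrays, without assigning anything to self. Of course, you can go through the code and check this for any of the other classifiers :)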
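If you do go through other classifiers' code, a rough first pass is to look for assignments to self inside predict. A minimal sketch, assuming you pass in the method's source text (for an installed estimator, inspect.getsource(SomeEstimator.predict) returns it); self_attribute_writes is a hypothetical helper, not part of scikit-learn. It is only a heuristic: an empty result does not prove thread safety, since predict may call other methods, compiled code, or mutate objects in place.

```python
import ast
import textwrap


def self_attribute_writes(source):
    """Return the names of self.<attr> attributes that a function's source assigns or deletes.

    Heuristic only: it sees direct targets such as `self.x = ...`, `self.x += ...`
    or `self.a, self.b = ...`, but NOT in-place mutation (`self.x[0] = ...`,
    `self.x.append(...)`), setattr(), or writes made inside other methods it calls.
    """
    tree = ast.parse(textwrap.dedent(source))
    return {
        node.attr
        for node in ast.walk(tree)
        if isinstance(node, ast.Attribute)
        and isinstance(node.ctx, (ast.Store, ast.Del))
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    }


# Source text is only parsed, never executed, so `np` need not exist here
read_only = '''
def predict(self, X):
    jll = self._joint_log_likelihood(X)
    return self.classes_[np.argmax(jll, axis=1)]
'''

caching = '''
def predict(self, X):
    self.last_X_ = X
    return self.classes_[0]
'''

print(self_attribute_writes(read_only))  # set()
print(self_attribute_writes(caching))    # {'last_X_'}
```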