SHAP 函数在绘图方法中引发异常

use*_*279 9 python plot machine-learning shap

Samples.zip 示例压缩文件夹包含:

  1. 模型.pkl
  2. x_test.csv

要重现问题,请执行以下步骤:

  1. 用于lin2 =joblib.load('model.pkl')加载线性回归模型
  2. 用于x_test_2 = pd.read_csv('x_test.csv').drop(['Unnamed: 0'],axis=1)加载x_test_2
  3. 运行下面的代码来加载解释器
explainer_test = shap.Explainer(lin2.predict, x_test_2)
shap_values_test = explainer_test(x_test_2)
Run Code Online (Sandbox Code Playgroud)
  1. 然后运行partial_dependence_plot查看错误信息:

ValueError:x 和 y 不能大于二维,但具有形状 (2,) 和 (2, 1, 1)

sample_ind = 3
shap.partial_dependence_plot(
    "new_personal_projection_delta", 
    lin.predict, 
    x_test, model_expected_value=True,
    feature_expected_value=True, ice=False,
    shap_values=shap_values_test[sample_ind:sample_ind+1,:]
)
Run Code Online (Sandbox Code Playgroud)
  1. 运行另一个函数来绘制瀑布图以查看错误消息:

例外:waterfall_plot 需要模型输出的标量 base_values 作为第一个参数,但您已传递一个数组作为第一个参数!尝试 shap.waterfall_plot(explainer.base_values[0], value[0], X[0]) 或对于多输出模型尝试 shap.waterfall_plot(explainer.base_values[0], value[0][0], X[ 0])。

shap.plots.waterfall(shap_values_test[sample_ind], max_display=14)

问题:

  1. 为什么我不能运行partial_dependence_plot& shap.plots.waterfall
  2. 我需要对输入进行哪些更改才能运行上述方法?

Ser*_*nov 7

您需要正确构造Explanation新绘图 API 所需的对象SHAP

将执行以下操作:

import joblib
import shap
import warnings
warnings.filterwarnings("ignore")

model =joblib.load('model.pkl')
data = pd.read_csv('x_test.csv').drop(['Unnamed: 0'],axis=1)
explainer = shap.Explainer(model.predict, data)
sv = explainer(data)

idx = 3
exp = shap.Explanation(sv.values, sv.base_values[0][0], sv.data)
shap.plots.waterfall(exp[idx])
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

shap.partial_dependence_plot(
    "x7",
    model.predict,
    data,
    model_expected_value=True,
    feature_expected_value=True,
    ice=False,
    shap_values=exp
)
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述