L X*_*dor 6 amazon-web-services scikit-learn boto3 xgboost amazon-sagemaker
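I'm very confused by my SageMaker model: it is giving me strange predictions, with many repeated values (20% of the predictions are identical). To troubleshoot, I downloaded the model and tested it locally, where it returned the results I expected.
So I'm now in the odd situation where the model gives different predictions in SageMaker than it does locally.
Here are the steps I took to reproduce the error (full code below):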
Booster object prediction: 0.9954053
Classifier object prediction: 0.9954053
Sagemaker endpoint prediction 0.693799495697
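I wrapped the Booster object in a classifier just to make sure that doesn't affect anything.
How can this happen? The endpoint uses exactly the same model artifact and data. Am I missing something about how the model gets loaded? As I understand it, nothing besides the model artifact and the input data defines what the endpoint returns, and both are identical here...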
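One thing that can be checked without the endpoint is whether turning the feature row into a text/csv body changes any values. The following is only a minimal sketch: `to_csv_payload` and `from_csv_payload` are hypothetical helpers, not the SageMaker SDK's `csv_serializer`, and how the container actually parses the body (for example how it treats 0.0 or -999.0) is not something this check can reveal.

```python
import csv
import io


def to_csv_payload(row):
    """Serialize one feature row to a single CSV line (hypothetical helper)."""
    buf = io.StringIO()
    csv.writer(buf, lineterminator="").writerow(row)
    return buf.getvalue()


def from_csv_payload(payload):
    """Parse a CSV line back into floats, the way a naive consumer might."""
    return [float(x) for x in next(csv.reader(io.StringIO(payload)))]


# A few values taken from the test row in the full code, not the whole row.
row = [2, 162.0, 0.21, 1, 18.0, 0.0, -999.0, 0.35]

payload = to_csv_payload(row)
restored = from_csv_payload(payload)

print(payload)
print(restored == [float(x) for x in row])
```

If the round trip is lossless, the serialized text itself is probably not what changes the prediction, which would leave the way the container interprets the values as the remaining difference to investigate.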
import pandas as pd
import time
import boto3, sagemaker
import numpy as np
from sagemaker.predictor import csv_serializer
import xgboost as xgb
import tarfile
import os
import pickle
sess = sagemaker.Session()
sm_client = boto3.client('sagemaker')
endpoint_config_name = 'week-2-endpoint-config-prod' # using an existing endpoint config for demo
endpoint_name = 'week2-temp'
# Set up endpoint
#create_endpoint_response = sm_client.create_endpoint(
#    EndpointName=endpoint_name,
#    EndpointConfigName=endpoint_config_name)
#time.sleep(800)
# Get model name used by endpoint
model = sm_client.describe_endpoint_config(EndpointConfigName = endpoint_config_name)['ProductionVariants'][0]['ModelName']
# Get model artifact url
artifact_url = sm_client.describe_model(ModelName = model)['PrimaryContainer']['ModelDataUrl']
# Download model artifacts
dest = 'tempmodel.tar.gz'
boto3.client('s3').download_file(
    Bucket='sagemaker-us-west-2-987938178880',
    # Drop the 38-character 's3://sagemaker-us-west-2-987938178880/' prefix to get the object key
    Key=artifact_url[38:],
    Filename=dest,
)
# Unpack and load the model
with tarfile.open(dest) as tf:
    tf.extractall()
with open("xgboost-model", "rb") as f:
    wk2_model = pickle.load(f)
os.remove(dest)
print('Model object loaded of type:', type(wk2_model))
# Set up SKlearn wrapper classifier
regr = xgb.XGBClassifier()
regr._Booster = wk2_model
# Set up endpoint
endpoint = sagemaker.predictor.RealTimePredictor(endpoint_name, sagemaker_session=sess)
endpoint.content_type = 'text/csv'
endpoint.serializer = csv_serializer
# Test data
feats = np.asarray([[2,162.0,0.21,1,18.0,0.0,0.0,0.0,3.33,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,31.0,31.0,38.0,38.0,0,0,1,0.35,-999.0,-999.0,-999.0,-999.0,-999,-999,-999,-999,24.0,55.0,56.0,30.0,62.0,64.0,755.0,1297.0,1466.0,7.0,11.0,13.0]])
print('')
# Booster: predict directly on a DMatrix
print('Booster object prediction:', wk2_model.predict(xgb.DMatrix(feats))[0])
# Classifier: sklearn wrapper around the same Booster
print('Classifier object prediction:', regr.predict_proba(feats)[:,1][0])
print('Sagemaker endpoint prediction', float(endpoint.predict(feats)))
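The download step above relies on a hard-coded bucket name and the slice `artifact_url[38:]`, which only works while the bucket name has exactly that length. A slightly more robust variant, sketched here with a hypothetical helper `split_s3_url` and a made-up example URL (only the bucket name comes from the code above), would derive both parts from the URL itself:

```python
from urllib.parse import urlparse


def split_s3_url(url):
    """Split 's3://bucket/key' into (bucket, key). Hypothetical helper."""
    parsed = urlparse(url)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URL: {url!r}")
    return parsed.netloc, parsed.path.lstrip("/")


# Made-up artifact URL for illustration; the real one comes from describe_model.
example_url = "s3://sagemaker-us-west-2-987938178880/example-prefix/output/model.tar.gz"

bucket, key = split_s3_url(example_url)
print(bucket)
print(key)
```

The resulting `bucket` and `key` could then be passed to `download_file` in place of the literal bucket name and the slice.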