相关疑难解决方法(0)

在pyspark UDF中使用tensorflow.keras模型会产生pickle错误

我想在 pysark pandas_udf 中使用 tensorflow.keras 模型。但是,在将模型发送给工作人员之前对其进行序列化时,我遇到了 pickle 错误。我不确定我是否使用最好的方法来执行我想要的操作,因此我将公开一个最小但完整的示例。

套餐:

  • tensorflow-2.2.0(但所有以前的版本也会触发错误)
  • pyspark-2.4.5

进口声明是:

import pandas as pd
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

from pyspark.sql import SparkSession, functions as F, types as T
Run Code Online (Sandbox Code Playgroud)

Pyspark UDF 是 pandas_udf:

def compute_output_pandas_udf(model):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output
Run Code Online (Sandbox Code Playgroud)

主要代码:

# Model parameters
weights = np.array([[0.5], …
Run Code Online (Sandbox Code Playgroud)

user-defined-functions apache-spark pyspark keras tensorflow

4
推荐指数
1
解决办法
6419
查看次数