小编Ham*_*mza的帖子

如何使用 pandas_udf 在 pyspark 数据帧上运行 pytorch 模型的推理(创建带有预测的新列)?

有没有办法以矢量化方式(使用pandas_udf?)在pyspark数据帧上运行pytorch模型的推理。

一行 udf 非常慢,因为需要为每一行加载模型 state_dict()。我正在尝试使用 pandas_udf 来加快速度,因为所有操作都可以在 pandas/pytorch 中有效地矢量化。

我已经查看了这个 databricks 帖子以获得灵感,但它与我的用例并不完全对应,因为我想对现有的 pyspark 数据框运行预测。

在这个简单的例子中,我可以使用一行 udf 让它工作:

import torch
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, DoubleType
import pandas as pd
import numpy as np

spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .getOrCreate()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Linear(5, 1)

    def forward(self, x):
        return self.w(x)

net = Net()
bc_model_state = spark.sparkContext.broadcast(net.state_dict()) …
Run Code Online (Sandbox Code Playgroud)

pandas apache-spark apache-spark-sql pyspark pytorch

8
推荐指数
1
解决办法
3212
查看次数