有没有办法以矢量化方式(使用pandas_udf?)在pyspark数据帧上运行pytorch模型的推理。
一行 udf 非常慢,因为需要为每一行加载模型 state_dict()。我正在尝试使用 pandas_udf 来加快速度,因为所有操作都可以在 pandas/pytorch 中有效地矢量化。
我已经查看了这个 databricks 帖子以获得灵感,但它与我的用例并不完全对应,因为我想对现有的 pyspark 数据框运行预测。
在这个简单的例子中,我可以使用一行 udf 让它工作:
import torch
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, DoubleType
import pandas as pd
import numpy as np
spark = SparkSession.builder.master('local[*]') \
.appName("model_training") \
.getOrCreate()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.w = nn.Linear(5, 1)
def forward(self, x):
return self.w(x)
net = Net()
bc_model_state = spark.sparkContext.broadcast(net.state_dict()) …Run Code Online (Sandbox Code Playgroud)