Cannot access a class method from a pyspark RDD's map method

Dev*_*eda 2 python rdd pyspark

While integrating pyspark into my application's code base, I found that I cannot reference a class's methods inside an RDD's map method. I reproduced the problem with a simple example, shown below.

Here is a dummy class I defined that simply adds a number to every element of an RDD, where the RDD is a class attribute:

from pyspark import SparkContext

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def add(self, a, b):
        return a + b

    def test_func(self, b):
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: self.add(l[1], b))
        v_c = v.collect()
        return v_c

test_func() calls map() on the RDD to build v, which in turn calls add() on every element. Calling test_func() throws the following error:

pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

Now, when I move the add() method out of the class:

def add(a, b):
    return a + b

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):

        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        v_c = v.collect()

        return v_c

Calling test_func() now works as expected (here with b = 5):

[7, 9, 11]

Why does this happen? How can I pass class methods to an RDD's map() method?

tom*_*mas 6

This happens because when pyspark tries to serialize your function, it also has to serialize the instance of your Test class (the function you pass to map holds a reference to the instance through self). That instance holds a reference to the SparkContext. You need to make sure the SparkContext is not referenced by any object that gets serialized and shipped to the workers; the SparkContext must live only on the driver.

This should work:

In a file testspark.py:

class Test(object):
    def add(self, a, b):
        return a + b

    def test_func(self, a_r, b):
        c_r = a_r.map(lambda l: (l[0], l[1] * 2))
        # now `self` has no reference to the SparkContext()
        v = c_r.map(lambda l: self.add(l[1], b)) 
        v_c = v.collect()
        return v_c

In your main script:

from pyspark import SparkContext
from testspark import Test

sc = SparkContext()
a = [('a', 1), ('b', 2), ('c', 3)]
a_r = sc.parallelize(a)

test = Test()
test.test_func(a_r, 5) # should give [7, 9, 11]
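Another option, if you want to keep the SparkContext on the instance as in the original class, is to make sure that whatever you pass to map() closes over plain values only and never over self. A rough sketch of that idea (untested, just to illustrate the point):

from pyspark import SparkContext

class Test(object):
    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):
        def add(x, y):  # local function, carries no reference to self
            return x + y
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        # the lambda's closure captures only `add` and `b`, so nothing
        # that references self.sc gets pickled and sent to the workers
        v = c_r.map(lambda l: add(l[1], b))
        return v.collect()

Here only add and b end up in the serialized closure, so the SparkContext stays on the driver.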