Serialize a custom transformer using python to be used within a Pyspark ML pipeline

Tec*_*ent 10 pyspark apache-spark-ml

I found the same discussion in the comments section of Create a custom Transformer in PySpark ML, but there is no clear answer. There is also an unresolved JIRA that corresponds to this: https://issues.apache.org/jira/browse/SPARK-17025.

Given that the Pyspark ML pipeline offers no option for saving a custom transformer written in python, what other options are there to get it done? How can I implement a _to_java method in my python class that returns a compatible java object?

Ben*_*nns 16

Spark 2.3.0 has a much, much better way to do this.

Simply extend DefaultParamsWritable and DefaultParamsReadable and your class will automatically have write and read methods that save your params and that are used by the PipelineModel serialization system.

The docs were not really clear, and I had to do a bit of source reading to understand how deserialization works:

  • PipelineModel.read instantiates a PipelineModelReader
  • PipelineModelReader loads the saved metadata and checks whether the language is 'Python'. If it is not, the typical JavaMLReader is used (which is what most of these answers are designed for)
  • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance

loadParamsInstance will find the class from the saved metadata, instantiate it, and call .load(path) on it. You can extend DefaultParamsReadable and get the DefaultParamsReader.load method for free. If you do need to implement specialized deserialization logic, I would look at that load method as a starting place.
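
As an illustration of that hook (this is only a sketch with made-up names, not code from Spark or from this answer): have read() return a DefaultParamsReader subclass whose load() does any extra rehydration after the params have been restored from the metadata.

from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsReader, DefaultParamsWritable

class _CustomReader(DefaultParamsReader):
    def load(self, path):
        # DefaultParamsReader.load restores the uid and params from the saved JSON metadata
        instance = super(_CustomReader, self).load(path)
        # ...any specialized deserialization (e.g. rebuilding non-Param attributes) would go here...
        return instance

class ReaderDemoTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """A do-nothing transformer, only here to show the read() override."""

    @classmethod
    def read(cls):
        return _CustomReader(cls)

    def _transform(self, dataset):
        return dataset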

Going the other way:

  • PipelineModel.write will check whether all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (which is what most of these answers are designed for)
  • Otherwise, PipelineWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
  • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.

You can extend DefaultParamsWritable to get the write method, which saves the metadata for your class and params in the right format. If you need to implement custom serialization logic, I would look at DefaultParamsWriter as a starting point.
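
Mirroring the reader sketch above, a write-side sketch (again with hypothetical names): have write() return a DefaultParamsWriter subclass whose saveImpl() can write extra artifacts next to the params metadata.

from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable, DefaultParamsWriter

class _CustomWriter(DefaultParamsWriter):
    def saveImpl(self, path):
        # DefaultParamsWriter.saveImpl writes the class name, uid and params as JSON metadata
        super(_CustomWriter, self).saveImpl(path)
        # ...any custom serialization (e.g. extra files under `path`) would go here...

class WriterDemoTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """A do-nothing transformer, only here to show the write() override."""

    def write(self):
        return _CustomWriter(self)

    def _transform(self, dataset):
        return dataset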

Ok, so finally, you have a pretty simple transformer that extends Params, with all of its parameters stored in the typical Params fashion:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform

class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset

Now we can use it:

from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()

The result:

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

  • For those getting the error `AttributeError: module '__main__' has no attribute 'YourTransformerClass'` when calling [`Pipeline.load(filename)`](https://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#pyspark.ml.PipelineModel.load), try registering `YourTransformerClass` on the current `__main__` module with `m = __import__("__main__"); setattr(m, 'YourTransformerClass', YourTransformerClass)` (a sketch follows below). See the source of [`DefaultParamsReader`](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.DefaultParamsReader) for why. (3)
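
A sketch of that workaround in context, assuming the SetValueTransformer from the example above now lives in a separate module (my_module is a placeholder name, not a real package):

import __main__
from pyspark.ml import PipelineModel
from my_module import SetValueTransformer  # placeholder: wherever the class is actually defined

# The saved metadata recorded the class as "__main__.SetValueTransformer",
# so register it on __main__ before DefaultParamsReader tries to resolve it.
setattr(__main__, "SetValueTransformer", SetValueTransformer)

pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')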

dmb*_*ker 9

I am not sure this is the best approach, but I too needed the ability to save custom Estimators, Transformers and Models that I had created in Pyspark, and also to support their use in the Pipeline API with persistence. Custom Pyspark Estimators, Transformers and Models can be created and used in the Pipeline API, but cannot be saved. This poses a problem in production when model training takes longer than the event prediction cycle.

In general, Pyspark Estimators, Transformers and Models are just wrappers around their Java or Scala equivalents, and the Pyspark wrappers just marshal the parameters to and from Java via py4j. Any persisting of the model is then done on the Java side. Because of this current structure, custom Pyspark Estimators, Transformers and Models are limited to living in the python world only.

In a previous attempt, I was able to save a single Pyspark model by using pickle/dill serialization. This worked well, but still did not allow saving or loading from within the Pipeline API. Pointed there by another SO post, however, I was directed to the OneVsRest classifier and inspected its _to_java and _from_java methods. They do all the heavy lifting on the Pyspark side. After looking at them I figured that, if one had a way to store the pickle dump inside an already made and supported savable java object, then it should be possible to save a custom Pyspark Estimator, Transformer or Model with the Pipeline API.

To that end, I found StopWordsRemover to be the ideal object to hijack, because it has an attribute, stopwords, that is a list of strings. The dill.dumps method returns a pickled representation of the object as a string. The plan was to turn that string into a list and then set the stopwords parameter of a StopWordsRemover to this list. Although it is a list of strings, I found that some of the characters would not marshal to the java object. So the characters get converted to integers, and the integers are then converted to strings. This all works great for saving a single instance, and also when saving within a Pipeline, because the Pipeline dutifully calls the _to_java method of my python class (we are still on the Pyspark side, so this works). But coming back from Java to Pyspark did not work in the Pipeline API.
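
As a stand-alone illustration of that encoding trick (Python 2, matching the code below; the sample object is arbitrary), the round trip looks roughly like this:

import dill

obj = {'anything': 'picklable'}             # stand-in for the custom python instance
dmp = dill.dumps(obj)                       # pickled representation as a (byte) string
pylist = [str(ord(c)) for c in dmp]         # characters -> integer codes -> strings, safe to marshal

# ...pylist is what gets stored as the stopwords of a StopWordsRemover...

restored = dill.loads(''.join(chr(int(d)) for d in pylist))
assert restored == obj
# Note: under Python 3, dill.dumps returns bytes, whose elements are already ints,
# so the ord()/chr() steps would be replaced by bytes handling.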

Because I am hiding my python object inside a StopWordsRemover instance, the Pipeline, when coming back to Pyspark, knows nothing about my hidden class object; it only knows it has a StopWordsRemover instance. Ideally it would be great to subclass Pipeline and PipelineModel, but alas that brings us back to trying to serialize a Python object. To get around this, I created a PysparkPipelineWrapper that takes a Pipeline or PipelineModel and just scans the stages, looking for a coded ID in the stopwords list (remember, that list is just the pickled bytes of my python object) that tells it to unwrap the list back into my instance and store it in the stage it came from. Below is code that shows how this all works.

For any custom Pyspark Estimator, Transformer or Model, just have it inherit from Identifiable, PysparkReaderWriter, MLReadable and MLWritable. Then, when loading a Pipeline or PipelineModel, pass it through PysparkPipelineWrapper.unwrap(pipeline).

This method does not address using the Pyspark code from Java or Scala, but at least we can save and load custom Pyspark Estimators, Transformers and Models and work with the Pipeline API.

import dill
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row

class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover

class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
                swords = stage.getStopWords()[:-1] # strip the id
                lst = [chr(int(d)) for d in swords]
                dmp = ''.join(lst)
                py_obj = dill.loads(dmp)
                stages[i] = py_obj

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline

class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our clarrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stopwords, which are the characters of the dill dump plus our guid,
        and convert, via dill, back to our python instance.
        """
        swords = java_obj.getStopWords()[:-1] # strip the id
        lst = [chr(int(d)) for d in swords] # convert from string integer list to bytes
        dmp = ''.join(lst)
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
        Use this list as a set of dummy stopwords and store in a StopWordsRemover instance
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)
        pylist = [str(ord(d)) for d in dmp] # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i in xrange(len(pylist)):
            java_array[i] = pylist[i]
        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj

class HasFake(Params):
    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)

class MockTransformer(Transformer, HasFake, Identifiable):
    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_count = 0

    def _transform(self, dataset):
        self.dataset_count = dataset.count()
        return dataset

class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self):
        super(MyTransformer, self).__init__()

def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Alice', age=5, height=80), Row(name='Alice', age=10, height=80)]).toDF()
    return df

def test1():
    trA = MyTransformer()
    trA.dataset_count = 999
    print trA.dataset_count
    trA.save('test.trans')
    trB = MyTransformer.load('test.trans')
    print trB.dataset_count

def test2():
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA])
    print type(pipeA)
    pipeA.save('testA.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
    stagesAA = pipeAA.getStages()
    trAA = stagesAA[0]
    print trAA.dataset_count

def test3():
    dfA = make_a_dataframe(sc)
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA]).fit(dfA)
    print type(pipeA)
    pipeA.save('testB.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
    stagesAA = pipeAA.stages
    trAA = stagesAA[0]
    print trAA.dataset_count
    dfB = pipeAA.transform(dfA)
    dfB.show()
Run Code Online (Sandbox Code Playgroud)