Accumulate arrays from earlier rows (PySpark dataframe)


A (Python) example will make my question clear. Suppose I have a Spark dataframe of people who watched certain movies on certain dates, like this:

movierecord = spark.createDataFrame(
    [("Alice", 1, ["Avatar"]),
     ("Bob", 2, ["Fargo", "Tron"]),
     ("Alice", 4, ["Babe"]),
     ("Alice", 6, ["Avatar", "Airplane"]),
     ("Alice", 7, ["Pulp Fiction"]),
     ("Bob", 9, ["Star Wars"])],
    ["name", "unixdate", "movies"])

The schema and dataframe defined above look as follows:

root
 |-- name: string (nullable = true)
 |-- unixdate: long (nullable = true)
 |-- movies: array (nullable = true)
 |    |-- element: string (containsNull = true)

+-----+--------+------------------+
|name |unixdate|movies            |
+-----+--------+------------------+
|Alice|1       |[Avatar]          |
|Bob  |2       |[Fargo, Tron]     |
|Alice|4       |[Babe]            |
|Alice|6       |[Avatar, Airplane]|
|Alice|7       |[Pulp Fiction]    |
|Bob  |9       |[Star Wars]       |
+-----+--------+------------------+
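For reference, the two printouts above can be reproduced like so (a minimal sketch, assuming an active SparkSession bound to the name spark):

# Print the inferred schema and the full, untruncated table.
movierecord.printSchema()
movierecord.show(truncate=False)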

From the above, I want to generate a new dataframe column that holds all of the previous movies seen by each user, without duplicates ("previous" relative to each row's unixdate field). So it should look like this:

+-----+--------+------------------+------------------------+
|name |unixdate|movies            |previous_movies         |
+-----+--------+------------------+------------------------+
|Alice|1       |[Avatar]          |[]                      |
|Bob  |2       |[Fargo, Tron]     |[]                      |
|Alice|4       |[Babe]            |[Avatar]                |
|Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
|Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
|Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
+-----+--------+------------------+------------------------+

How can I achieve this in an efficient way?


Using SQL only, without preserving the order of objects:

  • Required imports:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    
  • Window definition:

    w = Window.partitionBy("name").orderBy("unixdate")
    
  • Full solution (a small demo of the max(struct(...)) trick follows after the result):

    (movierecord
        # Flatten movies
        .withColumn("previous_movie", f.explode("movies"))
        # Collect unique
        .withColumn("previous_movies", f.collect_set("previous_movie").over(w))
        # Drop duplicates for a single unixdate
        .groupBy("name", "unixdate")
        .agg(f.max(f.struct(
            f.size("previous_movies"),
            f.col("movies").alias("movies"),
            f.col("previous_movies").alias("previous_movies")
        )).alias("tmp"))
        # Shift by one and extract
        .select(
            "name", "unixdate", "tmp.movies",
            f.lag("tmp.previous_movies", 1).over(w).alias("previous_movies")))
    
  • Result:

     +-----+--------+------------------+------------------------+
     |name |unixdate|movies            |previous_movies         |
     +-----+--------+------------------+------------------------+
     |Bob  |2       |[Fargo, Tron]     |null                    |
     |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
     |Alice|1       |[Avatar]          |null                    |
     |Alice|4       |[Babe]            |[Avatar]                |
     |Alice|6       |[Avatar, Airplane]|[Babe, Avatar]          |
     |Alice|7       |[Pulp Fiction]    |[Babe, Airplane, Avatar]|
     +-----+--------+------------------+------------------------+
    
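Why does grouping with max(struct(...)) drop the duplicates that explode introduced? Spark SQL compares structs field by field, so taking the max of a struct whose first field is the set size keeps, for each (name, unixdate) group, the row carrying the fullest previous_movies set. Here is a hypothetical standalone demo of that ordering behavior (not part of the original answer; it assumes the imports above and a SparkSession named spark):

# Hypothetical demo: max over a struct column compares fields left to
# right, so the row whose first field (the array size) is larger wins.
demo = spark.createDataFrame(
    [("a", 1, ["x"]), ("a", 1, ["x", "y"])], ["k", "n", "s"])
(demo
    .groupBy("k", "n")
    .agg(f.max(f.struct(f.size("s").alias("sz"), f.col("s").alias("s"))).alias("t"))
    .select("k", "n", "t.s")
    .show())
# keeps ["x", "y"], the row with the larger array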

SQL plus a Python UDF, preserving order:

  • Imports:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    from pyspark.sql import Column
    from pyspark.sql.types import ArrayType, StringType
    
    from typing import List, Union
    
    # https://github.com/pytoolz/toolz
    from toolz import unique, concat, compose
    
  • UDF (a quick sanity check follows after the result below):

    def flatten_distinct(col: Union[Column, str]) -> Column:
        def flatten_distinct_(xss: Union[List[List[str]], None]) -> List[str]:
            return compose(list, unique, concat)(xss or [])
        return f.udf(flatten_distinct_, ArrayType(StringType()))(col)
    
  • Window definition, same as before.

  • Full solution:

    (movierecord
        # Collect lists
        .withColumn("previous_movies", f.collect_list("movies").over(w))
        # Flatten and drop duplicates
        .withColumn("previous_movies", flatten_distinct("previous_movies"))
        # Shift by one
        .withColumn("previous_movies", f.lag("previous_movies", 1).over(w))
        # For presentation only
        .orderBy("unixdate")) 
    
  • Result:

    +-----+--------+------------------+------------------------+
    |name |unixdate|movies            |previous_movies         |
    +-----+--------+------------------+------------------------+
    |Alice|1       |[Avatar]          |null                    |
    |Bob  |2       |[Fargo, Tron]     |null                    |
    |Alice|4       |[Babe]            |[Avatar]                |
    |Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
    |Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
    |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
    +-----+--------+------------------+------------------------+
    
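And the quick sanity check of flatten_distinct promised above (a hypothetical standalone snippet, assuming the imports and the UDF definition from this answer):

# Hypothetical check: flatten nested arrays, keeping first-seen order.
df = spark.createDataFrame([(1, [["a", "b"], ["b", "c"]])], ["id", "xss"])
df.select(flatten_distinct("xss").alias("flat")).show()
# -> [a, b, c]  (duplicates dropped, first-seen order preserved)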

Performance:

I believe there is no efficient way to solve this problem given the constraints. Not only does the requested output require significant data duplication (the data is binary-encoded to fit the Tungsten format, so you get possible compression but lose object identity), but a number of the operations are expensive under the Spark computing model, including costly grouping and sorting.

This should be fine if the expected size of previous_movies is bounded and small, but it won't be feasible in general.

The data duplication is quite easy to address by keeping a single, lazily built history per user. That is not something that can be done in SQL, but it is quite easy with low-level RDD operations, as the sketch below illustrates.
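A minimal sketch of that idea (hypothetical, not part of the original answer): sort each user's rows by unixdate and build one ordered, duplicate-free history per user in a single pass. Each output row still materializes its own prefix on the way back to a DataFrame, but grouping and sorting happen only once per user:

def add_history(rows):
    # rows: one user's (name, unixdate, movies) tuples, sorted by unixdate
    seen, seen_set = [], set()  # ordered, duplicate-free history so far
    for name, unixdate, movies in rows:
        yield name, unixdate, movies, list(seen)
        for m in movies:
            if m not in seen_set:
                seen_set.add(m)
                seen.append(m)

previous = (movierecord.rdd
    .map(lambda r: (r.name, (r.name, r.unixdate, r.movies)))
    .groupByKey()
    .flatMap(lambda kv: add_history(sorted(kv[1], key=lambda t: t[1])))
    .toDF(["name", "unixdate", "movies", "previous_movies"]))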

The explode and collect_* pattern is expensive. If your requirements are strict but you want to improve performance, you can use a Scala UDF in place of the Python one.