Get a correlation matrix from arrays in a column

Ror*_*ory 8 apache-spark apache-spark-sql pyspark

I have a dataframe:

data = [['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
        ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
        ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
        ['t4', ['u8', 'u9'], 3],
        ['t5', ['u9', 'u22', 'u11'], 3],
        ['t6', ['u5', 'u11', 'u22', 'u4'], 3]]
sdf = spark.createDataFrame(data, schema=['label', 'id', 'day'])
sdf.show()
+-----+--------------------+---+
|label|                  id|day|
+-----+--------------------+---+
|   t1|[u1, u2, u3, u4, u5]|  1|
|   t2|    [u1, u7, u8, u5]|  1|
|   t3|   [u1, u2, u7, u11]|  2|
|   t4|            [u8, u9]|  3|
|   t5|      [u9, u22, u11]|  3|
|   t6|  [u5, u11, u22, u4]|  3|
+-----+--------------------+---+

I want to compute a correlation matrix (in reality my dataframe is much larger):

I want to cross the id column across days. That is, for day=1 I do not intersect IDs within that same day; those cases are set to 0. I intersect day 1 with day 2 and day 3, and so on.

Also, when a label is intersected with itself, the result should be 0 rather than 100 (the diagonal is 0).

In the matrix I want to record the absolute size of each intersection (how many IDs overlap). For example, t1 (day 1) and t3 (day 2) share u1 and u2, so that entry is 2. It should produce a dataframe like this:

+---+-----+---+---+---+---+---+---+
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  0|
|  2|   t3|  2|  2|  0|  0|  0|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+

Since my real dataset is quite large, I would like a solution that does not require too much memory and does not cause the job to fail.

Ber*_*ler 5

First, you can flatten the ID lists with explode:

>>> from pyspark.sql.functions import explode
>>> from pyspark.sql.types import StructType, StructField, StringType, ArrayType
>>> schema = StructType([
...     StructField('label', StringType(), nullable=False),
...     StructField('ids', ArrayType(StringType(), containsNull=False), nullable=False),
...     StructField('day', StringType(), nullable=False),
... ])
>>> data = [
...     ['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
...     ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
...     ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
...     ['t4', ['u8', 'u9'], 3],
...     ['t5', ['u9', 'u22', 'u11'], 3],
...     ['t6', ['u5', 'u11', 'u22', 'u4'], 3]
... ]
>>> id_lists_df = spark.createDataFrame(data, schema=schema)
>>> df = id_lists_df.select('label', 'day', explode('ids').alias('id'))
>>> df.show()
+-----+---+---+                                                                 
|label|day| id|
+-----+---+---+
|   t1|  1| u1|
|   t1|  1| u2|
|   t1|  1| u3|
|   t1|  1| u4|
|   t1|  1| u5|
|   t2|  1| u1|
|   t2|  1| u7|
|   t2|  1| u8|
|   t2|  1| u5|
|   t3|  2| u1|
|   t3|  2| u2|
|   t3|  2| u7|
|   t3|  2|u11|
|   t4|  3| u8|
|   t4|  3| u9|
|   t5|  3| u9|
|   t5|  3|u22|
|   t5|  3|u11|
|   t6|  3| u5|
|   t6|  3|u11|
+-----+---+---+
only showing top 20 rows

Then you can self-join the resulting dataframe, filter out unwanted rows (same day or same label), and proceed with the actual counting.

My impression is that your matrix will contain a lot of zeros. Do you need the "physical" matrix, or would a count per day and pair of labels be enough?

If the "physical" matrix is not needed, a regular aggregation will do (group by day and labels, then count):

>>> df2 = df.withColumnRenamed('label', 'label2').withColumnRenamed('day', 'day2')
>>> counts = df.join(df2, on='id') \
...     .where(df.label != df2.label2) \
...     .where(df.day != df2.day2) \
...     .groupby(df.day, df.label, df2.label2) \
...     .count() \
...     .orderBy(df.label, df2.label2)
>>> 
>>> counts.show()
+---+-----+------+-----+                                                        
|day|label|label2|count|
+---+-----+------+-----+
|  1|   t1|    t3|    2|
|  1|   t1|    t6|    2|
|  1|   t2|    t3|    2|
|  1|   t2|    t4|    1|
|  1|   t2|    t6|    1|
|  2|   t3|    t1|    2|
|  2|   t3|    t2|    2|
|  2|   t3|    t5|    1|
|  2|   t3|    t6|    1|
|  3|   t4|    t2|    1|
|  3|   t5|    t3|    1|
|  3|   t6|    t1|    2|
|  3|   t6|    t2|    1|
|  3|   t6|    t3|    1|
+---+-----+------+-----+

>>> counts.explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2010]
      +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
         +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2007]
            +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
               +- Project [label#0, day#2, label2#493]
                  +- SortMergeJoin [id#7], [id#504], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                     :- Sort [id#7 ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=1999]
                     :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                     :        +- Filter (size(ids#1, true) > 0)
                     :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                     +- Sort [id#504 ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#504, 200), ENSURE_REQUIREMENTS, [plan_id=2000]
                           +- Project [label#501 AS label2#493, day#503 AS day2#497, id#504]
                              +- Generate explode(ids#502), [label#501, day#503], false, [id#504]
                                 +- Filter (size(ids#502, true) > 0)
                                    +- Scan ExistingRDD[label#501,ids#502,day#503]

If you do need the "physical" matrix, you can use MLlib as suggested in the first answer, or you can pivot on label2 instead of using it as a grouping column:

>>> counts_pivoted = df.join(df2, on='id') \
...     .where(df.label != df2.label2) \
...     .where(df.day != df2.day2) \
...     .groupby(df.day, df.label) \
...     .pivot('label2') \
...     .count() \
...     .drop('label2') \
...     .orderBy('label') \
...     .fillna(0)
>>> counts_pivoted.show()                                                       
+---+-----+---+---+---+---+---+---+                                             
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  1|
|  2|   t3|  2|  2|  0|  0|  1|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+

>>> counts_pivoted.explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [day#2, label#0, coalesce(t1#574L, 0) AS t1#616L, coalesce(t2#575L, 0) AS t2#617L, coalesce(t3#576L, 0) AS t3#618L, coalesce(t4#577L, 0) AS t4#619L, coalesce(t5#578L, 0) AS t5#620L, coalesce(t6#579L, 0) AS t6#621L]
   +- Sort [label#0 ASC NULLS FIRST], true, 0
      +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2744]
         +- Project [day#2, label#0, __pivot_count(1) AS count AS `count(1) AS count`#573[0] AS t1#574L, __pivot_count(1) AS count AS `count(1) AS count`#573[1] AS t2#575L, __pivot_count(1) AS count AS `count(1) AS count`#573[2] AS t3#576L, __pivot_count(1) AS count AS `count(1) AS count`#573[3] AS t4#577L, __pivot_count(1) AS count AS `count(1) AS count`#573[4] AS t5#578L, __pivot_count(1) AS count AS `count(1) AS count`#573[5] AS t6#579L]
            +- HashAggregate(keys=[day#2, label#0], functions=[pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
               +- Exchange hashpartitioning(day#2, label#0, 200), ENSURE_REQUIREMENTS, [plan_id=2740]
                  +- HashAggregate(keys=[day#2, label#0], functions=[partial_pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
                     +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
                        +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2736]
                           +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
                              +- Project [label#0, day#2, label2#493]
                                 +- SortMergeJoin [id#7], [id#543], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                                    :- Sort [id#7 ASC NULLS FIRST], false, 0
                                    :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=2728]
                                    :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                                    :        +- Filter (size(ids#1, true) > 0)
                                    :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                                    +- Sort [id#543 ASC NULLS FIRST], false, 0
                                       +- Exchange hashpartitioning(id#543, 200), ENSURE_REQUIREMENTS, [plan_id=2729]
                                          +- Project [label#540 AS label2#493, day#542 AS day2#497, id#543]
                                             +- Generate explode(ids#541), [label#540, day#542], false, [id#543]
                                                +- Filter (size(ids#541, true) > 0)
                                                   +- Scan ExistingRDD[label#540,ids#541,day#542]

The values are not exactly the same as in your example, but I believe Werner's comment is correct.

The pivot variant is probably less efficient. If the list of possible labels is available in advance, you can save some time by passing it as the second argument to pivot.
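
As a minimal sketch (assuming the full label list is known up front and hard-coded here for the toy data), passing the values to pivot spares Spark the extra job it would otherwise run to collect the distinct label2 values:

>>> # pre-computed list of all labels; in practice this would come from your metadata
>>> labels = ['t1', 't2', 't3', 't4', 't5', 't6']
>>> counts_pivoted = df.join(df2, on='id') \
...     .where(df.label != df2.label2) \
...     .where(df.day != df2.day2) \
...     .groupby(df.day, df.label) \
...     .pivot('label2', labels) \
...     .count() \
...     .orderBy('label') \
...     .fillna(0)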