如何检查给定键的所有记录是否已在同一分区中?

Jac*_*ski 7 apache-spark

我想尽可能避免按密钥重新分区数据,并知道给定密钥的所有记录是否已经在同一个分区中.

Spark中是否有内置函数可以给我答案?

zer*_*323 1

不是内置的,但如果您假设特定的分区器,则很容易实现您自己的功能:

import org.apache.spark.rdd.RDD
import org.apache.spark.Partitioner
import scala.reflect.ClassTag

def checkDistribution[K : ClassTag, V : ClassTag](
   rdd: RDD[(K, V)], partitioner: Partitioner) = 
  // If partitioner is set we compare partitioners 
  rdd.partitioner.map(_ == partitioner).getOrElse {
    // Otherwise check if correct number of partitions 
    rdd.partitions.size ==  partitioner.numPartitions &&
    //  And check if distribution matches partitioner
    rdd.keys.mapPartitionsWithIndex((i, iter) => 
      Iterator(iter.forall(x => partitioner.getPartition(x) == i))
    ).fold(true)(_ && _)
  }
Run Code Online (Sandbox Code Playgroud)

一些测试:

import org.apache.spark.HashPartitioner

val rdd = sc.range(0, 20, 5).map((_, None))
Run Code Online (Sandbox Code Playgroud)

在不假设特定分区程序的情况下,想到的唯一选择需要洗牌,因此它不太可能是一种改进。

def checkDistribution[K : ClassTag, V : ClassTag](rdd: RDD[(K, V)]) =
   rdd.keys.mapPartitionsWithIndex((i, iter) => iter.map((_, i)))
     .combineByKey(
       x => Seq(x), 
       (x: Seq[Int], y: Int) => x, 
       (x: Seq[Int], y: Seq[Int]) => x ++ y)  // Should be more or less OK
     .values
     .mapPartitions(iter => Iterator(iter.forall(_.size == 1)))
     .fold(true)(_ && _)
Run Code Online (Sandbox Code Playgroud)

一项可能的改进是您可以使用相同的逻辑来自动定义Partitioner数据。如果您collectAsMap之前values检查过所有Seqs大小均为 1,那么您就有了一个有效的分区器,可以保证没有网络流量。