Spark中的分层抽样

add*_*ons 19 scala apache-spark

我有包含用户和购买数据的数据集.下面是一个示例,其中第一个元素是userId,第二个元素是productId,第三个元素是boolean.

(2147481832,23355149,1)
(2147481832,973010692,1)
(2147481832,2134870842,1)
(2147481832,541023347,1)
(2147481832,1682206630,1)
(2147481832,1138211459,1)
(2147481832,852202566,1)
(2147481832,201375938,1)
(2147481832,486538879,1)
(2147481832,919187908,1)
... 
Run Code Online (Sandbox Code Playgroud)

我想确保我只占用每个用户数据的80%并构建RDD,同时占用20%的剩余部分并构建另一个RDD.让我们来电话和测试.我想远离使用groupBy开始,因为它可以创建内存问题,因为数据集很大.什么是最好的方法呢?

我可以做以下但这不会给每个用户80%.

val percentData = data.map(x => ((math.random * 100).toInt, x._1. x._2, x._3)
val train = percentData.filter(x => x._1 < 80).values.repartition(10).cache()
Run Code Online (Sandbox Code Playgroud)

eli*_*sah 23

一个可能的解决方案是霍尔顿的答案,这里有一些其他解决方案:

使用RDD:

您可以使用PairRDDFunctions类中的sampleByKeyExact转换.

sampleByKeyExact(boolean withReplacement,scala.collection.Map fractions,long seed)返回通过key(通过分层抽样)采样的此RDD的子集,其中包含每个层的完全math.ceil(numItems*samplingRate)(具有相同键的对的组) ).

这就是我要做的:

考虑以下列表:

val seq = Seq(
                (2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),
                (2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),
                (2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)
           )
Run Code Online (Sandbox Code Playgroud)

我会创建一个RDDPair,将所有用户映射为键:

val data = sc.parallelize(seq).map(x => (x._1,(x._2,x._3)))
Run Code Online (Sandbox Code Playgroud)

然后我将为fractions每个键设置如下,因为为每个键 sampleByKeyExact获取一个分数的Map:

val fractions = data.map(_._1).distinct.map(x => (x,0.8)).collectAsMap
Run Code Online (Sandbox Code Playgroud)

我在这里做的是在键上映射以找到不同的键,然后将每个键与等于的分数相关联0.8.我将整体收集为地图.

现在来样品:

import org.apache.spark.rdd.PairRDDFunctions
val sampleData = data.sampleByKeyExact(false, fractions, 2L)
Run Code Online (Sandbox Code Playgroud)

要么

val sampleData = data.sampleByKeyExact(withReplacement = false, fractions = fractions,seed = 2L)
Run Code Online (Sandbox Code Playgroud)

您可以检查密钥或数据或数据样本的计数:

scala > data.count
// [...]
// res10: Long = 12

scala > sampleData.count
// [...]
// res11: Long = 10
Run Code Online (Sandbox Code Playgroud)

使用DataFrames:

让我们考虑seq前一节中的相同数据().

val df = seq.toDF("keyColumn","value1","value2")
df.show
// +----------+----------+------+
// | keyColumn|    value1|value2|
// +----------+----------+------+
// |2147481832|  23355149|     1|
// |2147481832| 973010692|     1|
// |2147481832|2134870842|     1|
// |2147481832| 541023347|     1|
// |2147481832|1682206630|     1|
// |2147481832|1138211459|     1|
// |2147481832| 852202566|     1|
// |2147481832| 201375938|     1|
// |2147481832| 486538879|     1|
// |2147481832| 919187908|     1|
// | 214748183| 919187908|     1|
// | 214748183|  91187908|     1|
// +----------+----------+------+
Run Code Online (Sandbox Code Playgroud)

我们将需要底层RDD来做我们RDD通过将我们的键定义为第一列来创建元素的元组:

val data: RDD[(Int, Row)] = df.rdd.keyBy(_.getInt(0))
val fractions: Map[Int, Double] = data.map(_._1)
                                      .distinct
                                      .map(x => (x, 0.8))
                                      .collectAsMap

val sampleData: RDD[Row] = data.sampleByKeyExact(withReplacement = false, fractions, 2L)
                               .values

val sampleDataDF: DataFrame = spark.createDataFrame(sampleData, df.schema) // you can use sqlContext.createDataFrame(...) instead for spark 1.6)
Run Code Online (Sandbox Code Playgroud)

您现在可以检查密钥df或数据样本的计数:

scala > df.count
// [...]
// res9: Long = 12

scala > sampleDataDF.count
// [...]
// res10: Long = 10
Run Code Online (Sandbox Code Playgroud)

Spark 1.5.0开始,你可以使用DataFrameStatFunctions.sampleBy方法:

df.stat.sampleBy("keyColumn", fractions, seed)
Run Code Online (Sandbox Code Playgroud)