小编cst*_*eel的帖子

如何使用 pyspark 从 Spark 获取批量行

我有一个包含超过 60 亿行数据的 Spark RDD,我想使用 train_on_batch 来训练深度学习模型。我无法将所有行放入内存中,因此我希望一次获得 10K 左右的数据,以批量处理为 64 或 128 块(取决于模型大小)。我目前正在使用 rdd.sample() 但我不认为这保证我会获得所有行。有没有更好的方法来分区数据以使其更易于管理,以便我可以编写生成器函数来获取批次?我的代码如下:

data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}') # 6B+
data_sample = data_df.sample(True, 0.0000015).take(6400) 
sample_df = data_sample.toPandas()

def get_batch():
  for row in sample_df.itertuples():
    # TODO: put together a batch size of BATCH_SIZE
    yield row

for i in range(10):
    print(next(get_batch()))
Run Code Online (Sandbox Code Playgroud)

python apache-spark rdd pyspark

8
推荐指数
2
解决办法
1万
查看次数

标签 统计

apache-spark ×1

pyspark ×1

python ×1

rdd ×1