J *_*ath 52 pivot scala dataframe apache-spark apache-spark-sql
我开始使用Spark DataFrames,我需要能够透过数据来创建多列的1列中的多列.在Scalding中有内置的功能,我相信Python中的Pandas,但我找不到任何新的Spark Dataframe.
我假设我可以编写某种类型的自定义函数,但是我甚至不确定如何启动,特别是因为我是Spark的新手.我有人知道如何使用内置功能或如何在Scala中编写内容的建议,非常感谢.
zer*_*323 69
正如David Anderson 提到的,Spark pivot从版本1.6开始提供功能.一般语法如下所示:
df
.groupBy(grouping_columns)
.pivot(pivot_column, [values])
.agg(aggregate_expressions)
Run Code Online (Sandbox Code Playgroud)
用法示例使用nycflights13和csv格式:
Python:
from pyspark.sql.functions import avg
flights = (sqlContext
.read
.format("csv")
.options(inferSchema="true", header="true")
.load("flights.csv")
.na.drop())
flights.registerTempTable("flights")
sqlContext.cacheTable("flights")
gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")
flights.count()
## 336776
%timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop
Run Code Online (Sandbox Code Playgroud)
斯卡拉:
val flights = sqlContext
.read
.format("csv")
.options(Map("inferSchema" -> "true", "header" -> "true"))
.load("flights.csv")
flights
.groupBy($"origin", $"dest", $"carrier")
.pivot("hour")
.agg(avg($"arr_delay"))
Run Code Online (Sandbox Code Playgroud)
Java:
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;
Dataset<Row> df = spark.read().format("csv")
.option("inferSchema", "true")
.option("header", "true")
.load("flights.csv");
df.groupBy(col("origin"), col("dest"), col("carrier"))
.pivot("hour")
.agg(avg(col("arr_delay")));
Run Code Online (Sandbox Code Playgroud)
R/SparkR:
library(magrittr)
flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)
flights %>%
groupBy("origin", "dest", "carrier") %>%
pivot("hour") %>%
agg(avg(column("arr_delay")))
Run Code Online (Sandbox Code Playgroud)
R/sparklyr
library(dplyr)
flights <- spark_read_csv(sc, "flights", "flights.csv")
avg.arr.delay <- function(gdf) {
expr <- invoke_static(
sc,
"org.apache.spark.sql.functions",
"avg",
"arr_delay"
)
gdf %>% invoke("agg", expr, list())
}
flights %>%
sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)
Run Code Online (Sandbox Code Playgroud)
SQL:
CREATE TEMPORARY VIEW flights
USING csv
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;
SELECT * FROM (
SELECT origin, dest, carrier, arr_delay, hour FROM flights
) PIVOT (
avg(arr_delay)
FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
);
Run Code Online (Sandbox Code Playgroud)
示例数据:
"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00
Run Code Online (Sandbox Code Playgroud)
性能考虑:
一般而言,枢转是一种昂贵的操作.
如果你可以尝试提供values清单:
vs = list(range(25))
%timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count()
## 10 loops, best of 3: 392 ms per loop
Run Code Online (Sandbox Code Playgroud)在某些情况下,它被证明是有益的(可能不再值得在2.0或更高版本中努力)repartition和/或预先聚合数据
仅用于重新整形,您可以使用first:如何使用数据透视表并计算非数字列的平均值(面向AnalysisException"不是数字列")?
相关问题:
J *_*ath 14
我通过编写for循环来动态创建SQL查询来克服这个问题.说我有:
id tag value
1 US 50
1 UK 100
1 Can 125
2 US 75
2 UK 150
2 Can 175
Run Code Online (Sandbox Code Playgroud)
而且我要:
id US UK Can
1 50 100 125
2 75 150 175
Run Code Online (Sandbox Code Playgroud)
我可以创建一个包含我想要转动的值的列表,然后创建一个包含我需要的SQL查询的字符串.
val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1
var query = "select *, "
for (i <- 0 to numCountries-1) {
query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
}
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"
myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)
Run Code Online (Sandbox Code Playgroud)
我可以创建类似的查询然后进行聚合.这不是一个非常优雅的解决方案,但它适用于任何值列表,并且在调用代码时也可以作为参数传递.
小智 5
我使用数据帧解决了类似的问题,步骤如下:
为您的所有国家/地区创建列,并将"值"作为值:
import org.apache.spark.sql.functions._
val countries = List("US", "UK", "Can")
val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) =>
if(countryToCheck == countryInRow) value else 0
}
val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) }
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")
Run Code Online (Sandbox Code Playgroud)
您的数据框'dfWithCountries'将如下所示:
+--+--+---+---+
|id|US| UK|Can|
+--+--+---+---+
| 1|50| 0| 0|
| 1| 0|100| 0|
| 1| 0| 0|125|
| 2|75| 0| 0|
| 2| 0|150| 0|
| 2| 0| 0|175|
+--+--+---+---+
Run Code Online (Sandbox Code Playgroud)
现在,您可以将所需结果的所有值相加:
dfWithCountries.groupBy("id").sum(countries: _*).show
Run Code Online (Sandbox Code Playgroud)
结果:
+--+-------+-------+--------+
|id|SUM(US)|SUM(UK)|SUM(Can)|
+--+-------+-------+--------+
| 1| 50| 100| 125|
| 2| 75| 150| 175|
+--+-------+-------+--------+
Run Code Online (Sandbox Code Playgroud)
虽然这不是一个非常优雅的解决方案.我必须创建一系列函数来添加所有列.此外,如果我有很多国家,我会将我的临时数据集扩展到一个非常宽的集合,有很多零.
有一个简单的旋转方法:
id tag value
1 US 50
1 UK 100
1 Can 125
2 US 75
2 UK 150
2 Can 175
import sparkSession.implicits._
val data = Seq(
(1,"US",50),
(1,"UK",100),
(1,"Can",125),
(2,"US",75),
(2,"UK",150),
(2,"Can",175),
)
val dataFrame = data.toDF("id","tag","value")
val df2 = dataFrame
.groupBy("id")
.pivot("tag")
.max("value")
df2.show()
+---+---+---+---+
| id|Can| UK| US|
+---+---+---+---+
| 1|125|100| 50|
| 2|175|150| 75|
+---+---+---+---+
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
40163 次 |
| 最近记录: |