使用dplyr full_join为data.frame展开expand.grid

Bam*_*mqf 5 merge join r dplyr

我正在尝试获得类似于的功能expand.grid并在上工作data.frame

我在替代方法中为data.frames的expand.grid找到了一种解决方案,该解决方案 使用merge函数来实现此目的。

由于mergedplyrAlternative 相比速度相当慢full_join,因此我尝试使用它full_join来实现此功能,但无法正确完成。这是我失败的示例:

df <- data.frame(attribute = paste0('attr', rep(1:5, each=2)),
                 value = paste0(rep(1:5, each=2), rep(c('A','B'), 2)),
                 score = runif(10))
df
   attribute value      score
1      attr1    1A 0.75600171
2      attr1    1B 0.07086242
3      attr2    2A 0.92403325
4      attr2    2B 0.63414169
5      attr3    3A 0.78763834
6      attr3    3B 0.88576568
7      attr4    4A 0.75998967
8      attr4    4B 0.25205845
9      attr5    5A 0.99304728
10     attr5    5B 0.70389605
Run Code Online (Sandbox Code Playgroud)

我试图分裂df通过attribute并加入评分者名单:

dfList <- df %>%
  mutate(attribute=1) %>%
  split(df$attribute)
Run Code Online (Sandbox Code Playgroud)

然后通过以下方式将所有这5个表“ expand.grid”在一起:

Reduce(function(x, y) {full_join(x, y, by=c('attribute'='attribute'))}, dfList)
Run Code Online (Sandbox Code Playgroud)

但是,结果很奇怪:

   attribute value.x    score.x value.y   score.y value.x    score.x value.y   score.y value     score
1          1      1A 0.75600171      2A 0.9240333      1A 0.75600171      2A 0.9240333    5A 0.9930473
2          1      1A 0.75600171      2A 0.9240333      1A 0.75600171      2A 0.9240333    5B 0.7038961
3          1      1A 0.75600171      2A 0.9240333      1A 0.75600171      2A 0.9240333    5A 0.9930473
4          1      1A 0.75600171      2A 0.9240333      1A 0.75600171      2A 0.9240333    5B 0.7038961
...
Run Code Online (Sandbox Code Playgroud)

前两个表显示两次,这是不希望的。但是,当我在前4个表上尝试时,它可以完美地工作:

Reduce(function(x, y) {full_join(x, y, by=c('attribute'='attribute'))}, dfList[1:4])

   attribute value.x    score.x value.y   score.y value.x   score.x value.y   score.y
1          1      1A 0.75600171      2A 0.9240333      3A 0.7876383      4A 0.7599897
2          1      1A 0.75600171      2A 0.9240333      3A 0.7876383      4B 0.2520584
3          1      1A 0.75600171      2A 0.9240333      3B 0.8857657      4A 0.7599897
4          1      1A 0.75600171      2A 0.9240333      3B 0.8857657      4B 0.2520584 
...
Run Code Online (Sandbox Code Playgroud)

我哪里做错了?

我在Ubuntu 14.04上使用dplyr带有R版本的0.4.33.2.4

ins*_*ven 2

我可以在我的机器上为您重现损坏的结果dfList。在我看来,我已经找到了为什么会发生这种情况。

require(dplyr)

adf <- data.frame(c1 = 7, c1 = 8, jv = 1, check.names = F)
bdf <- data.frame(d1 = 1:3, d2 = letters[1:3], jv = 1)
cdf <- data.frame(v1.x = 1:3, v2 = letters[1:3], jv = 1)
ddf <- data.frame(v2 = 4:5, v2.x = letters[4:5], jv = 1)

full_join(adf, bdf, by = "jv") 

  c1 c1 jv d1 d2
1  7  7  1  1  a
2  7  7  1  2  b
3  7  7  1  3  c
Run Code Online (Sandbox Code Playgroud)

我们可以注意到,重复的列名会adf导致连接结果错误。当我们在 的帮助下应用多个连接的链时Reduce,会自动重命名重复的列名称(默认情况下会添加.x.y)。这可能会导致产生另一个重复的名称(与它想要避免的事情相反)。

full_join(cdf, ddf, by = "jv")

  v1.x v2.x jv v2.y v2.x
1    1    a  1    4    d
2    1    a  1    5    e
3    2    b  1    4    d
4    2    b  1    5    e
5    3    c  1    4    d
6    3    c  1    5    e
Run Code Online (Sandbox Code Playgroud)

data.frames在这里,我们在不同的- 列中有一个重复的名称v2,在应用后缀 - 后被另一个重复的名称替换v2.x

因此,为了使事情顺利进行,我们应该关心data.frame我们要加入的列的唯一名称。

我尝试了几种方法来获得所需的结果,并想展示它们是什么。

  • 基本 R 解决方案使用merge,它是为了速度比较而设计的。
  • full_join使用fromdplyr包的方法
  • data.table使用顺序smerge的解决方案dt
  • 一个基于 的tidyr函数unnest
  • 另一个data.table解决方案,首先生成目标结果长度的密钥表(在 的帮助下CJ),然后进行几个左连接
  • 与之前相同,但使用on参数进行连接而不是设置键

require(data.table)
require(dplyr)
require(tidyr)
require(stringi)
require(microbenchmark)

expand.grid.df_base <- function(...) {
  dfList <- list(...)
  if (length(dfList) == 1) dfList <- dfList[[1]]

  if (is.null(names(dfList))) names(dfList) <- paste0("df", 1:length(dfList))

  lapply(1:length(dfList), function(i) 
    data.frame(dfN = i, colN = 1:length(dfList[[i]]), 
               dfname = names(dfList)[i], colname = names(dfList[[i]]),
               stringsAsFactors = F)) %>% bind_rows %>%
    mutate(dum_names = stri_rand_strings(nrow(.), 12)) %>% rowwise %>% 
    mutate(out_names = paste(dfname, colname, sep = ".")) %>% ungroup -> manage_names

  for (i in 1:nrow(manage_names)) names(dfList[[manage_names$dfN[i]]])[manage_names$colN[i]] <- manage_names$dum_names[i]

  Reduce(function(x, y) merge(x, y, by = NULL), dfList) %>% setNames(manage_names$out_names)
}

expand.grid.df_dplyr <- function(...) {
  dfList <- list(...)
  if (length(dfList) == 1) dfList <- dfList[[1]]

  if (is.null(names(dfList))) names(dfList) <- paste0("df", 1:length(dfList))

  lapply(1:length(dfList), function(i) 
    data.frame(dfN = i, colN = 1:length(dfList[[i]]), 
               dfname = names(dfList)[i], colname = names(dfList[[i]]),
               stringsAsFactors = F)) %>% bind_rows %>%
    mutate(dum_names = stri_rand_strings(nrow(.), 12)) %>% rowwise %>% 
    mutate(out_names = paste(dfname, colname, sep = ".")) %>% ungroup -> manage_names

  for (i in 1:nrow(manage_names)) names(dfList[[manage_names$dfN[i]]])[manage_names$colN[i]] <- manage_names$dum_names[i]

  joinvar <- stri_rand_strings(1, 12)

  Reduce(function(x, y) {
    mutate_def <- list(1L)
    names(mutate_def) <- joinvar

    full_join(x %>% mutate_(.dots = mutate_def), y %>% mutate_(.dots = mutate_def), by = joinvar)
  }, dfList) %>% select(-contains(joinvar)) %>% setNames(manage_names$out_names) %>% tbl_df
}

expand.grid.dt <- function(...) {
  dtList <- list(...)
  if (length(dtList) == 1) dtList <- dtList[[1]]

  if (!all(sapply(dtList, is.data.table))) dtList <- lapply(dtList, as.data.table)

  if (is.null(names(dtList))) setnames(dtList, paste0("dt", 1:length(dtList)))

  lapply(1:length(dtList), function(i) 
    data.frame(dfN = i, colN = 1:length(dtList[[i]]), 
               dfname = names(dtList)[i], colname = names(dtList[[i]]),
               stringsAsFactors = F)) %>% bind_rows %>%
    mutate(dum_names = stri_rand_strings(nrow(.), 12)) %>% rowwise %>% 
    mutate(out_names = paste(dfname, colname, sep = ".")) %>% ungroup -> manage_names

  for (i in 1:nrow(manage_names)) setnames(dtList[[manage_names$dfN[i]]], old = manage_names$colN[i], new = manage_names$dum_names[i])

  joinvar <- stri_rand_strings(1, 12)

  setnames(Reduce(function(x, y) merge(copy(x)[,(joinvar) := 1], copy(y)[,(joinvar) := 1], 
                                       by = joinvar, all = T, allow.cartesian = T), dtList)[,(joinvar) := NULL],
           manage_names$out_names)[]
}

expand.grid.df_tidyr <- function(...) {
  dfList <- list(...)
  if (length(dfList) == 1) dfList <- dfList[[1]]

  if (is.null(names(dfList))) names(dfList) <- paste0("df", 1:length(dfList))

  lapply(1:length(dfList), function(i) 
    data.frame(dfN = i, colN = 1:length(dfList[[i]]), 
               dfname = names(dfList)[i], colname = names(dfList[[i]]),
               stringsAsFactors = F)) %>% bind_rows %>%
    mutate(dum_names = stri_rand_strings(nrow(.), 12)) %>% rowwise %>% 
    mutate(out_names = paste(dfname, colname, sep = ".")) %>% ungroup -> manage_names

  for (i in 1:nrow(manage_names)) names(dfList[[manage_names$dfN[i]]])[manage_names$colN[i]] <- manage_names$dum_names[i]

  Reduce(function(x, y) x %>% rowwise %>% mutate(dfcol = list(y)) %>% ungroup %>% unnest(dfcol), dfList) %>% 
    setNames(manage_names$out_names) %>% tbl_df
}

expand.grid.dt2 <- function(...) {
  dtList <- list(...)
  if (length(dtList) == 1) dtList <- dtList[[1]]

  dum_names <- stri_rand_strings(length(dtList), 12)

  dtList <- lapply(1:length(dtList), function(i) 
    setkeyv(as.data.table(dtList[[i]])[, (dum_names[i]) := .I], dum_names[i]))

  Reduce(function(result, dt) setkeyv(result, names(result)[1])[dt][, (names(result)[1]) := NULL],
         dtList,
         setnames(do.call(CJ, c(sapply(dtList, function(df) seq_len(nrow(df))), list(sorted = F))), dum_names))[]
}

expand.grid.dt3 <- function(...) {

  dtList <- list(...)
  if (length(dtList) == 1) dtList <- dtList[[1]]

  dum_names <- stri_rand_strings(length(dtList), 12)

  dtList <- lapply(1:length(dtList), function(i) as.data.table(dtList[[i]])[, (dum_names[i]) := .I])

  Reduce(function(result, dt) result[dt, on = names(result)[1]][, (names(result)[1]) := NULL],
         dtList,
         setnames(do.call(CJ, c(sapply(dtList, function(df) seq_len(nrow(df))), list(sorted = F))), dum_names))[]
}
Run Code Online (Sandbox Code Playgroud)

现在让我们创建 s 列表data.frame来测试此函数。

set.seed(1)
bigdfList <- data.frame(type = sample(letters[1:10], 50, T),
                        categ = sample(LETTERS[1:10], 50, T),
                        num = sample(100L:500L, 50, T),
                        val = rnorm(50)) %>% split(., .$type)

smalldfList <- data.frame(type = sample(letters[1:5], 50, T),
                          categ = sample(LETTERS[1:5], 50, T),
                          num = sample(100L:500L, 50, T),
                          val = rnorm(50)) %>% split(., .$type)
Run Code Online (Sandbox Code Playgroud)

的扩展连接生成维度为 和 的smalldfList表-占用 1230.5 MB RAM。[60,480 x 20]bigdfList[6,451,200 x 40]

从...开始smalldfList

microbenchmark(expand.grid.df_base(smalldfList), expand.grid.df_dplyr(smalldfList), 
               expand.grid.dt(smalldfList), expand.grid.df_tidyr(smalldfList),
               expand.grid.dt2(smalldfList), expand.grid.dt3(smalldfList), times = 10)

Unit: milliseconds
                              expr       min        lq      mean    median        uq       max neval cld
  expand.grid.df_base(smalldfList) 178.36192 188.54955 201.28729 198.79644 209.86934 229.85360    10  b 
 expand.grid.df_dplyr(smalldfList)  16.04555  16.91327  18.91094  17.64907  18.45307  29.58192    10 a  
       expand.grid.dt(smalldfList)  20.33188  21.42275  26.30034  23.22873  31.66666  39.37922    10 a  
 expand.grid.df_tidyr(smalldfList) 722.06572 738.02188 801.41820 792.23725 859.96186 905.99190    10   c
      expand.grid.dt2(smalldfList)  32.22650  33.68353  36.89386  36.39713  37.39182  48.93550    10 a  
      expand.grid.dt3(smalldfList)  29.13399  30.69299  34.51265  34.03198  37.48651  41.73543    10 a  
Run Code Online (Sandbox Code Playgroud)

因此,tidyr解决方案根本不是这里的选择,基础merge也很慢。其他4个功能就bigdfList表现出以下效率。

microbenchmark(expand.grid.df_dplyr(bigdfList), expand.grid.dt(bigdfList), 
               expand.grid.dt2(bigdfList), expand.grid.dt3(bigdfList), times = 10)

Unit: seconds
                            expr       min        lq      mean    median        uq       max neval  cld
 expand.grid.df_dplyr(bigdfList)  1.326336  1.354706  1.456805  1.449781  1.481836  1.703158    10 a   
       expand.grid.dt(bigdfList)  1.763174  1.820004  1.894813  1.893910  1.939879  2.127097    10  b  
      expand.grid.dt2(bigdfList) 14.164731 14.332872 14.452933 14.452221 14.551982 14.740852    10    d
      expand.grid.dt3(bigdfList) 10.589517 10.828548 11.104010 11.021519 11.368172 11.976976    10   c 
Run Code Online (Sandbox Code Playgroud)

并且该dplyr::full_join解决方案具有最佳结果!

dplyr也许,这是真正比 更好的选项之一data.table,也许是我缺乏data.table知识,这阻止了我制作一个真正快速的函数:-)