Generating deterministic random numbers in duckdb with dplyr syntax

Ash*_*wad 5 r dplyr apache-arrow duckdb

How can I use duckdb's setseed() function (see the reference documentation) together with dplyr syntax to make sure the analysis below is reproducible?

# dplyr version 1.1.1
# arrow version 11.0.0.3
# duckdb 0.7.1.1
out_dir <- tempfile()
arrow::write_dataset(mtcars, out_dir, partitioning = "cyl")

mtcars_ds <- arrow::open_dataset(out_dir)

mtcars_smry <- mtcars_ds |>
  arrow::to_duckdb() |>
  dplyr::mutate(
    fold = ceiling(3 * random())
  ) |>
  dplyr::summarize(
    avg_hp = mean(hp),
    .by = c(cyl, fold)
  )

mtcars_smry |>
  dplyr::collect()
#> Warning: Missing values are always removed in SQL aggregation functions.
#> Use `na.rm = TRUE` to silence this warning
#> This warning is displayed once every 8 hours.
#> # A tibble: 9 × 3
#>     cyl  fold avg_hp
#>   <int> <dbl>  <dbl>
#> 1     4     1   92  
#> 2     4     3   82.3
#> 3     4     2   74.5
#> 4     8     2  183. 
#> 5     8     3  210  
#> 6     8     1  300. 
#> 7     6     3  110  
#> 8     6     1  117  
#> 9     6     2  175

Created on 2023-08-27 with reprex v2.0.2


r2e*_*ans 4

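setseed() needs to be called inside a query. That may not feel entirely natural, since it returns null/NA, but at least the intent is clear. We can put it in its own "query".

A quick helper function to make this convenient: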

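use_setseed <- function(tab, seed = 0.5) {
  # `tab` is a lazy database table; the query only runs once it is collected.
  # `!!seed` inlines the R value, so a table column named `seed` cannot shadow it.
  ign <- tab |>
    dplyr::summarize(a = setseed(!!seed)) |>
    head(n = 1) |>
    dplyr::collect()
  # The result carries no useful data, so return nothing visibly.
  invisible(NULL)
}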

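One important note: the query has to be "realized" (typically by calling collect() on it) for the setseed() call to actually be executed. Since we need to realize it but do not need any of its data, I shrink what is sent back by "summarizing" down to one row and one column, collect that, and then discard it by returning invisible(NULL).

Also, I am working with a more direct connection here: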
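To make the lazy-evaluation point concrete, here is a minimal sketch. It assumes the DBI, duckdb, dplyr and dbplyr packages are installed; the names `con`, `lazy` and `ign` are only illustrative. Building the pipeline should only construct SQL, and nothing reaches setseed() until the result is collected.

```r
# Minimal sketch: assumes DBI, duckdb, dplyr and dbplyr are installed.
con <- DBI::dbConnect(duckdb::duckdb())   # in-memory database
DBI::dbWriteTable(con, "mtcars", mtcars)

# Building the pipeline only constructs SQL; setseed() has not run yet.
lazy <- dplyr::tbl(con, "mtcars") |>
  dplyr::summarize(a = setseed(0.5)) |>
  head(n = 1)

# Prints the generated SQL without executing the setseed() query.
dplyr::show_query(lazy)

# collect() sends the query to DuckDB; this is when the seed is set.
ign <- dplyr::collect(lazy)

DBI::dbDisconnect(con, shutdown = TRUE)
```

Inspecting the SQL with show_query() is a cheap way to confirm that setseed() ends up in the statement DuckDB will receive.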

duck <- DBI::dbConnect(duckdb::duckdb())
DBI::dbWriteTable(duck, "mtcars", mtcars)
mtcars_tbl <- dplyr::tbl(duck, "mtcars")

From here, we just need to call use_setseed immediately before the random query.

use_setseed(mtcars_tbl)
mtcars_tbl |>
  dplyr::mutate(fold = ceiling(3 * random())) |>
  dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) )
# # Source:   SQL [9 x 3]
# # Database: DuckDB 0.8.1 [r2@Linux 6.2.0-27-generic:R 4.2.3/:memory:]
#     cyl  fold avg_hp
#   <dbl> <dbl>  <dbl>
# 1     6     1  110  
# 2     4     1   83.5
# 3     8     3  210  
# 4     6     3  114  
# 5     4     3   79.7
# 6     6     2  149  
# 7     8     1  174  
# 8     8     2  252. 
# 9     4     2   97  

# validation
use_setseed(mtcars_tbl)
res1 <- mtcars_tbl |>
  dplyr::mutate(fold = ceiling(3 * random())) |> 
  dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) ) |>
  dplyr::collect()
resn <- replicate(10, {
  use_setseed(mtcars_tbl)
  mtcars_tbl |>
    dplyr::mutate(fold = ceiling(3 * random())) |> 
    dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) ) |> 
    dplyr::collect()
}, simplify=FALSE)
sapply(resn, identical, res1)
#  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE

Similarly, if you have the duckdb connection object itself, this version of the function reduces the bandwidth even further:

use_setseed2 <- function(con, seed=0.5) {
  DBI::dbExecute(con, "select setseed(?) as ign", params = list(seed))
  invisible(NULL)
}

Call it with the duckdb connection object, like this:

use_setseed2(duck) # note 'duck' and not 'mtcars_tbl'
mtcars_tbl |>
  dplyr::mutate(fold = ceiling(3 * random())) |> 
  dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) )
# same as above

# validation
use_setseed2(duck)
res1 <- mtcars_tbl |>
  dplyr::mutate(fold = ceiling(3 * random())) |> 
  dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) ) |>
  dplyr::collect()
resn <- replicate(10, {
  use_setseed2(duck)
  mtcars_tbl |>
    dplyr::mutate(fold = ceiling(3 * random())) |> 
    dplyr::summarize(avg_hp = mean(hp), .by = c(cyl, fold) ) |> 
    dplyr::collect()
}, simplify=FALSE)
sapply(resn, identical, res1)
#  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
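
As an optional guard, the connection-based helper can validate the seed in R before any SQL is sent. This is a sketch of my own, not part of the approach above: `use_setseed_checked` is a hypothetical name, and the bounds rest on my understanding that DuckDB rejects setseed() values outside -1 to 1, which is worth confirming against the DuckDB documentation for your version.

```r
# Hypothetical variant of use_setseed2() that validates the seed first.
# Assumption: DuckDB only accepts setseed() values between -1 and 1.
use_setseed_checked <- function(con, seed = 0.5) {
  # Fail early in R, before any SQL is sent to the database.
  stopifnot(
    is.numeric(seed), length(seed) == 1L, !is.na(seed),
    seed >= -1, seed <= 1
  )
  DBI::dbExecute(con, "select setseed(?) as ign", params = list(seed))
  invisible(NULL)
}
```

It is called the same way as before, for example use_setseed_checked(duck, 0.42), and an out-of-range seed stops with an R error instead of a database error.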