是否有一种 RAM 有效的方法来计算补集的中位数?

Jul*_*ann 11 ram r median data.table

我正在寻找一种 RAM 有效的方法来在 data.table 的帮助下计算补集的中位数。

对于来自不同组的一组观察结果,我对“其他组”中位数的实现感兴趣。即,如果一个 data.table 有一个值列和一个分组列,我想为每个组计算除当前之外的所有其他组中值的中位数。例如,对于第 1 组,我们计算除属于第 1 组的值以外的所有值的中位数,依此类推。

一个具体的例子 data.table

dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2))
dt
#    value groupId
# 1:     1       1
# 2:     2       1
# 3:     3       2
# 4:     4       2
# 5:     5       2
Run Code Online (Sandbox Code Playgroud)

我希望将medianOfAllTheOtherGroups 定义为第2 组的1.5 并定义第1 组的4,对同一data.table 中的每个条目重复:

dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2), medianOfAllTheOtherGroups = c(4, 4, 1.5, 1.5, 1.5))

dt
#    value groupId medianOfAllTheOtherGroups
# 1:     1       1                       4.0 # median of all groups _except_ 1
# 2:     2       1                       4.0
# 3:     3       2                       1.5 # median of all groups _except_ 2
# 4:     4       2                       1.5  
# 5:     5       2                       1.5
Run Code Online (Sandbox Code Playgroud)

为了只计算每个组的中位数而不是每个观察值,我们使用循环进行了实现。当前的完整实现对于作为输入的小型 data.tables 效果很好,但是对于较大的数据集会消耗大量 RAM,并且循环中调用的中位数是瓶颈(注意:对于实际用例,我们有一个 dt 为 3.000。 000 行和 100.000 组)。我在改善 RAM 消耗方面的工作很少。对于我在下面提供的最小示例,专家可以在这里帮助改进 RAM 吗?

最小示例:

library(data.table)
set.seed(1)
numberOfGroups <- 10
numberOfValuesPerGroup <- 100

# Data table with column
# groupIds - Ids for the groups available
# value - value we want to calculate the median over
# includeOnly - boolean that indicates which example should get a "group specific" median
dt <-
  data.table(
    groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
  )

# calculate the median from all observations for those groups that do not 
# require a separate treatment
medianOfAllGroups <-  median(dt$value)
dt$medianOfAllTheOtherGroups <- medianOfAllGroups


# generate extra data.table to collect results for selected groups
includedGroups <-  dt[, unique(groupId)]
dt_otherGroups <- 
  data.table(groupId = includedGroups,
             medianOfAllTheOtherGroups =  as.numeric(NA)
  )

# loop over all selected groups and calculate the median from all observations
# except of those that belong to this group
for (id in includedGroups){
  dt_otherGroups[groupId == id, 
                 medianOfAllTheOtherGroups := median(dt[groupId != id, value])]
}

# merge subset data to overall data.table
dt[dt_otherGroups, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
   on = c("groupId")]

Run Code Online (Sandbox Code Playgroud)

PS:这里是 10 个组的示例输出,每个组有 100 个观察值:

dt
#      groupId  value medianOfAllTheOtherGroups
#   1:       1 0.2655                   0.48325
#   2:       1 0.3721                   0.48325
#   3:       1 0.5729                   0.48325
#   4:       1 0.9082                   0.48325
#   5:       1 0.2017                   0.48325
# ---
#  996:      10 0.7768                   0.48590
#  997:      10 0.6359                   0.48590
#  998:      10 0.2821                   0.48590
#  999:      10 0.1913                   0.48590
# 1000:      10 0.2655                   0.48590
Run Code Online (Sandbox Code Playgroud)

最小示例的不同设置的一些数字(在具有 16Gb RAM 的 Mac Book Pro 上测试):

组数 每组值的数量 内存 (GB) 运行时间(秒)
500 50 0.48 1.47
5000 50 39.00 58.00
50 5000 0.42 0.65

所有内存值都是从 profvis 的输出中提取的,请参见此处最小示例的示例屏幕截图: 专业输出

Col*_*ole 5

中位数是已排序数据集的中点。对于数据集中的奇数个值,中位数就是中间的数字。对于数据集中的偶数个值,中位数是最靠近中间的两个数字的平均值。

为了进行演示,请考虑 1:8 的简单向量

1 | 2 | 3 |** 4 | 5 **| 6 | 7 | 8

在本例中,我们的中点是 4.5。因为这是一个非常简单的例子,中位数本身就是 4.5

现在考虑分组,其中一个分组是向量的第一个值。也就是说,我们的组只有 1。我们知道这会将我们的中位数向右移动(即更大),因为我们删除了分布的低值。我们的新分布是 2:8,中位数现在是 5。

2 | 3 | 4 | *5* | 6 | 7 | 8

只有当我们能够确定这些转变之间的关系时,这才有意义。具体来说,我们最初的中点是 4.5。基于原始向量的新中点是 5。

让我们演示一个由 1、3 和 7 组成的组的较大混合。在本例中,我们有 2 个值低于原始中点,一个值高于原始中点。我们的新中位数是 5:

2 | 4 | ** 5 ** | 6 | 8

因此,根据经验,我们确定,从分布中删除较小的数字会使我们的中点索引移动 0.5,从分布中删除较大的数字会使我们的中点索引移动 -0.5。还有一些其他规定:

我们需要确保我们的分组索引不在新的 mid_point 计算中。考虑一组 1、2 和 5。根据我的数学计算,我们将根据(2 below - 1 above) / 25 的新中点向上移动 0.5。这是错误的,因为 5 已经用完!我们需要考虑到这一点。

3 | 4 | ** 6 ** | 7 | 8

同样,随着中点的移动,我们还需要回头验证我们的排名值是否仍然一致。在 1:20 的序列中,考虑一组c(1:9, 11). 虽然 11 最初高于原始中点 10.5,但它并不高于移动后的中点(9 below - 1 above ) / 214.5。但我们的实际中位数是 15.5,因为 11 现在低于新的中点。

10 | 12 | 13 | 14 | ** 15 | 16 **| 17 | 18 | 19 | 20

TL:DR 代码是什么?

上面的所有示例中,分组的排名向量都在中通过特殊符号给出I,假设我们这样做了setorder()。如果我们进行与上面相同的数学计算,我们就不必浪费时间对数据集进行子集化。相反,我们可以根据从分布中删除的内容来确定新索引应该是什么。


setorder(dt, value)  

nr = nrow(dt)
is_even = nr %% 2L == 0L
mid_point = (nr + 1L) / 2L

dt[, medianOfAllTheOtherGroups :=
     {
       below = sum(.I < mid_point)
     is_midpoint = is_even && below && (.I[below] + 1L == mid_point)
     
     above = .N - below - is_midpoint
     new_midpoint = (below - above) / 2L + mid_point
     ## TODO turn this into a loop incase there are multiple values that this is true
     if (new_midpoint > mid_point && above &&.I[below + 1] < new_midpoint) { ## check to make sure that none of the indices were above
       below = below - 1L
       new_midpoint = new_midpoint + 1L
     } else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
       below = below + 1L
       new_midpoint = new_midpoint - 1L
     }
     if (((nr - .N + 1L) %% 2L) == 0L) {
       dt$value[new_midpoint]
     } else {
       ##TODO turn this into a loop in case there are multiple values that this is true for.
       default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
       if (below) {
         if (.I[below] == default_inds[1L])
           default_inds[1L] = .I[below] - 1L
       }
       if (above) {
         if (.I[below + 1L + is_midpoint] == default_inds[2L])
           default_inds[2L] = .I[below + 1L] + 1L
       }
       mean(dt$value[default_inds])
     }
     }
   , by = groupId]

Run Code Online (Sandbox Code Playgroud)

表现

这是使用bench::mark它检查所有结果是否相等。对于 Henrik 和我的解决方案,我确实将结果重新排序回原始分组,以便它们全部相等。

请注意,虽然这种(复杂的)算法是最有效的,但我确实想强调,其中大多数算法可能不会达到极端的RAM 使用峰值。其他答案必须子集 5,000 次才能分配长度为 249,950 的向量来计算新的中位数。仅分配时每个循环就大约有 2 MB(例如,总共 10 GB)。

# A tibble: 6 x 13
  expression            min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result          memory        time    gc      
  <bch:expr>       <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>          <list>        <list>  <list>  
1 cole              225.7ms  271.8ms    3.68      6.34MB    
2 henrik_smart_med    17.7s    17.7s    0.0564   23.29GB    
3 henrik_base_med      1.6m     1.6m    0.0104   41.91GB    
4 henrik_fmed         55.9s    55.9s    0.0179   32.61GB    
5 christian_lookup    54.7s    54.7s    0.0183   51.39GB    
6 talat_unlist        35.9s    35.9s    0.0279   19.02GB     
Run Code Online (Sandbox Code Playgroud) 完整档案代码
library(data.table)
library(collapse)
set.seed(76)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50

dt <-
  data.table(
    groupId = (rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup, 0, 10), 4)
  )

## this is largely instantaneous.
dt[ , ri := .I]

bench::mark( cole = {
  setorder(dt, value)
  
  nr = nrow(dt)
  is_even = nr %% 2L == 0L
  mid_point = (nr + 1L) / 2L
  
  dt[, medianOfAllTheOtherGroups :=
       {
         below = sum(.I < mid_point)
         is_midpoint = is_even && below && (.I[below] + 1L == mid_point)
         
         above = .N - below - is_midpoint
         new_midpoint = (below - above) / 2L + mid_point
         ## TODO turn this into a loop incase there are multiple values that this is true
         if (new_midpoint > mid_point && above &&.I[below + 1] < new_midpoint) { ## check to make sure that none of the indices were above
           below = below - 1L
           new_midpoint = new_midpoint + 1L
         } else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
           below = below + 1L
           new_midpoint = new_midpoint - 1L
         }
         if (((nr - .N + 1L) %% 2L) == 0L) {
           as.numeric(dt$value[new_midpoint])
         } else {
           ##TODO turn this into a loop in case there are multiple values that this is true for.
           default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
           if (below) {
             if (.I[below] == default_inds[1L])
               default_inds[1L] = .I[below] - 1L
           }
           if (above) {
             if (.I[below + 1L + is_midpoint] == default_inds[2L])
               default_inds[2L] = .I[below + 1L] + 1L
           }
           mean(dt$value[default_inds])
         }
       }
     , by = groupId]
  
  setorder(dt, ri)

},
henrik_smart_med = {
  
  # number of rows in original data    
  nr = nrow(dt)
  
  # order by value
  setorder(dt, value)
  
  dt[ , medianOfAllTheOtherGroups := {
    
    # length of "other"
    n = nr - .N
    
    # ripped from median
    half = (n + 1L) %/% 2L
    if (n %% 2L == 1L) dt$value[-.I][half]
    else mean(dt$value[-.I][half + 0L:1L])
    
  }, by = groupId]
  setorder(dt, ri)
},
henrik_base_med = {
  dt[ , med := median(dt$value[-.I]), by = groupId]
},
henrik_fmed = {
  dt[ , med := fmedian(dt$value[-.I]), by = groupId]
}, 
christian_lookup = {
  nrows <- dt[, .N]
  dt_match <- dt[, .(nrows_other = nrows- .N), by = .(groupId_match = groupId)]
  dt_match[, odd := nrows_other %% 2]
  dt_match[, idx1 := ceiling(nrows_other/2)]
  dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]
  
  setkey(dt, value)
  dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
  dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
     on = c(groupId = "groupId_match")]
},
talat_unlist = {
  d2 = dt[, .(value = list(value)), keyby = .(groupId)]
  setkey(dt, groupId)
  dt[, medianOfAllTheOtherGroups := 
       fmedian(d2[-.GRP, unlist(value, use.names = FALSE, recursive = FALSE)]), 
     by = .(groupId)]  
})
Run Code Online (Sandbox Code Playgroud)


Chr*_*rck 3

精确结果的方法:中位数是排序向量的“中间”值。(或者偶数长度向量的两个中间值的平均值)如果我们知道其他向量的排序长度,我们可以直接查找相应的向量元素索引来获取中位数,从而避免实际计算中位数 n*groupId 次:

library(data.table)
set.seed(1)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50

dt <-
  data.table(
    groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
    value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
  )

# group count match table + idx position for median of others
nrows <- dt[, .N]
dt_match <- dt[, .(nrows_other = nrows- .N), by = .(groupId_match = groupId)]
dt_match[, odd := nrows_other %% 2]
dt_match[, idx1 := ceiling(nrows_other/2)]
dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]

setkey(dt, value)
dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups, 
 on = c(groupId = "groupId_match")]
Run Code Online (Sandbox Code Playgroud)

我猜想,可能还有更多类似 data.table 的方法可以进一步提高性能。

numberOfGroups = 5000 且 numberOfValuesPerGroup = 50 的内存/运行时:20GB,27000ms