Jul*_*ann 11 ram r median data.table
我正在寻找一种 RAM 有效的方法来在 data.table 的帮助下计算补集的中位数。
对于来自不同组的一组观察结果,我对“其他组”的中位数的实现感兴趣。即,如果一个 data.table 有一个值列和一个分组列,我想为每个组计算除当前组之外的所有其他组中值的中位数。例如,对于第 1 组,我们计算除属于第 1 组的值以外的所有值的中位数,依此类推。
一个具体的例子 data.table
dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2))
dt
# value groupId
# 1: 1 1
# 2: 2 1
# 3: 3 2
# 4: 4 2
# 5: 5 2
Run Code Online (Sandbox Code Playgroud)
我希望将medianOfAllTheOtherGroups 定义为第2 组的1.5 并定义第1 组的4,对同一data.table 中的每个条目重复:
dt <- data.table(value = c(1,2,3,4,5), groupId = c(1,1,2,2,2), medianOfAllTheOtherGroups = c(4, 4, 1.5, 1.5, 1.5))
dt
# value groupId medianOfAllTheOtherGroups
# 1: 1 1 4.0 # median of all groups _except_ 1
# 2: 2 1 4.0
# 3: 3 2 1.5 # median of all groups _except_ 2
# 4: 4 2 1.5
# 5: 5 2 1.5
Run Code Online (Sandbox Code Playgroud)
为了只计算每个组的中位数而不是每个观察值,我们使用循环进行了实现。当前的完整实现对于作为输入的小型 data.tables 效果很好,但是对于较大的数据集会消耗大量 RAM,并且循环中调用的中位数是瓶颈(注意:对于实际用例,我们有一个 dt 为 3.000。 000 行和 100.000 组)。我在改善 RAM 消耗方面的工作很少。对于我在下面提供的最小示例,专家可以在这里帮助改进 RAM 吗?
最小示例:
library(data.table)
set.seed(1)
numberOfGroups <- 10
numberOfValuesPerGroup <- 100
# Data table with column
# groupIds - Ids for the groups available
# value - value we want to calculate the median over
# includeOnly - boolean that indicates which example should get a "group specific" median
dt <-
data.table(
groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
)
# calculate the median from all observations for those groups that do not
# require a separate treatment
medianOfAllGroups <- median(dt$value)
dt$medianOfAllTheOtherGroups <- medianOfAllGroups
# generate extra data.table to collect results for selected groups
includedGroups <- dt[, unique(groupId)]
dt_otherGroups <-
data.table(groupId = includedGroups,
medianOfAllTheOtherGroups = as.numeric(NA)
)
# loop over all selected groups and calculate the median from all observations
# except of those that belong to this group
for (id in includedGroups){
dt_otherGroups[groupId == id,
medianOfAllTheOtherGroups := median(dt[groupId != id, value])]
}
# merge subset data to overall data.table
dt[dt_otherGroups, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups,
on = c("groupId")]
Run Code Online (Sandbox Code Playgroud)
PS:这里是 10 个组的示例输出,每个组有 100 个观察值:
dt
# groupId value medianOfAllTheOtherGroups
# 1: 1 0.2655 0.48325
# 2: 1 0.3721 0.48325
# 3: 1 0.5729 0.48325
# 4: 1 0.9082 0.48325
# 5: 1 0.2017 0.48325
# ---
# 996: 10 0.7768 0.48590
# 997: 10 0.6359 0.48590
# 998: 10 0.2821 0.48590
# 999: 10 0.1913 0.48590
# 1000: 10 0.2655 0.48590
Run Code Online (Sandbox Code Playgroud)
最小示例的不同设置的一些数字(在具有 16Gb RAM 的 Mac Book Pro 上测试):
组数 | 每组值的数量 | 内存 (GB) | 运行时间(秒) |
---|---|---|---|
500 | 50 | 0.48 | 1.47 |
5000 | 50 | 39.00 | 58.00 |
50 | 5000 | 0.42 | 0.65 |
中位数是已排序数据集的中点。对于数据集中的奇数个值,中位数就是中间的数字。对于数据集中的偶数个值,中位数是最靠近中间的两个数字的平均值。
为了进行演示,请考虑 1:8 的简单向量
1 | 2 | 3 |** 4 | 5 **| 6 | 7 | 8
在本例中,我们的中点是 4.5。因为这是一个非常简单的例子,中位数本身就是 4.5
现在考虑分组,其中一个分组是向量的第一个值。也就是说,我们的组只有 1。我们知道这会将我们的中位数向右移动(即更大),因为我们删除了分布的低值。我们的新分布是 2:8,中位数现在是 5。
2 | 3 | 4 | *5* | 6 | 7 | 8
只有当我们能够确定这些转变之间的关系时,这才有意义。具体来说,我们最初的中点是 4.5。基于原始向量的新中点是 5。
让我们演示一个由 1、3 和 7 组成的组的较大混合。在本例中,我们有 2 个值低于原始中点,一个值高于原始中点。我们的新中位数是 5:
2 | 4 | ** 5 ** | 6 | 8
因此,根据经验,我们确定,从分布中删除较小的数字会使我们的中点索引移动 0.5,从分布中删除较大的数字会使我们的中点索引移动 -0.5。还有一些其他规定:
我们需要确保我们的分组索引不在新的 mid_point 计算中。考虑一组 1、2 和 5。根据我的数学计算,我们将根据(2 below - 1 above) / 2
5 的新中点向上移动 0.5。这是错误的,因为 5 已经用完!我们需要考虑到这一点。
3 | 4 | ** 6 ** | 7 | 8
同样,随着中点的移动,我们还需要回头验证我们的排名值是否仍然一致。在 1:20 的序列中,考虑一组c(1:9, 11)
. 虽然 11 最初高于原始中点 10.5,但它并不高于移动后的中点(9 below - 1 above ) / 2
14.5。但我们的实际中位数是 15.5,因为 11 现在低于新的中点。
10 | 12 | 13 | 14 | ** 15 | 16 **| 17 | 18 | 19 | 20
上面的所有示例中,分组的排名向量都在data.table中通过特殊符号给出I
,假设我们这样做了setorder()
。如果我们进行与上面相同的数学计算,我们就不必浪费时间对数据集进行子集化。相反,我们可以根据从分布中删除的内容来确定新索引应该是什么。
setorder(dt, value)
nr = nrow(dt)
is_even = nr %% 2L == 0L
mid_point = (nr + 1L) / 2L
dt[, medianOfAllTheOtherGroups :=
{
below = sum(.I < mid_point)
is_midpoint = is_even && below && (.I[below] + 1L == mid_point)
above = .N - below - is_midpoint
new_midpoint = (below - above) / 2L + mid_point
## TODO turn this into a loop incase there are multiple values that this is true
if (new_midpoint > mid_point && above &&.I[below + 1] < new_midpoint) { ## check to make sure that none of the indices were above
below = below - 1L
new_midpoint = new_midpoint + 1L
} else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
below = below + 1L
new_midpoint = new_midpoint - 1L
}
if (((nr - .N + 1L) %% 2L) == 0L) {
dt$value[new_midpoint]
} else {
##TODO turn this into a loop in case there are multiple values that this is true for.
default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
if (below) {
if (.I[below] == default_inds[1L])
default_inds[1L] = .I[below] - 1L
}
if (above) {
if (.I[below + 1L + is_midpoint] == default_inds[2L])
default_inds[2L] = .I[below + 1L] + 1L
}
mean(dt$value[default_inds])
}
}
, by = groupId]
Run Code Online (Sandbox Code Playgroud)
这是使用bench::mark
它检查所有结果是否相等。对于 Henrik 和我的解决方案,我确实将结果重新排序回原始分组,以便它们全部相等。
请注意,虽然这种(复杂的)算法是最有效的,但我确实想强调,其中大多数算法可能不会达到极端的RAM 使用峰值。其他答案必须子集 5,000 次才能分配长度为 249,950 的向量来计算新的中位数。仅分配时每个循环就大约有 2 MB(例如,总共 10 GB)。
# A tibble: 6 x 13
expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result memory time gc
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list> <list> <list> <list>
1 cole 225.7ms 271.8ms 3.68 6.34MB
2 henrik_smart_med 17.7s 17.7s 0.0564 23.29GB
3 henrik_base_med 1.6m 1.6m 0.0104 41.91GB
4 henrik_fmed 55.9s 55.9s 0.0179 32.61GB
5 christian_lookup 54.7s 54.7s 0.0183 51.39GB
6 talat_unlist 35.9s 35.9s 0.0279 19.02GB
Run Code Online (Sandbox Code Playgroud)
完整档案代码
library(data.table)
library(collapse)
set.seed(76)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50
dt <-
data.table(
groupId = (rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
value = round(runif(n = numberOfGroups * numberOfValuesPerGroup, 0, 10), 4)
)
## this is largely instantaneous.
dt[ , ri := .I]
bench::mark( cole = {
setorder(dt, value)
nr = nrow(dt)
is_even = nr %% 2L == 0L
mid_point = (nr + 1L) / 2L
dt[, medianOfAllTheOtherGroups :=
{
below = sum(.I < mid_point)
is_midpoint = is_even && below && (.I[below] + 1L == mid_point)
above = .N - below - is_midpoint
new_midpoint = (below - above) / 2L + mid_point
## TODO turn this into a loop incase there are multiple values that this is true
if (new_midpoint > mid_point && above &&.I[below + 1] < new_midpoint) { ## check to make sure that none of the indices were above
below = below - 1L
new_midpoint = new_midpoint + 1L
} else if (new_midpoint < mid_point && below && .I[below] > new_midpoint) {
below = below + 1L
new_midpoint = new_midpoint - 1L
}
if (((nr - .N + 1L) %% 2L) == 0L) {
as.numeric(dt$value[new_midpoint])
} else {
##TODO turn this into a loop in case there are multiple values that this is true for.
default_inds = as.integer(new_midpoint + c(-0.5, 0.5))
if (below) {
if (.I[below] == default_inds[1L])
default_inds[1L] = .I[below] - 1L
}
if (above) {
if (.I[below + 1L + is_midpoint] == default_inds[2L])
default_inds[2L] = .I[below + 1L] + 1L
}
mean(dt$value[default_inds])
}
}
, by = groupId]
setorder(dt, ri)
},
henrik_smart_med = {
# number of rows in original data
nr = nrow(dt)
# order by value
setorder(dt, value)
dt[ , medianOfAllTheOtherGroups := {
# length of "other"
n = nr - .N
# ripped from median
half = (n + 1L) %/% 2L
if (n %% 2L == 1L) dt$value[-.I][half]
else mean(dt$value[-.I][half + 0L:1L])
}, by = groupId]
setorder(dt, ri)
},
henrik_base_med = {
dt[ , med := median(dt$value[-.I]), by = groupId]
},
henrik_fmed = {
dt[ , med := fmedian(dt$value[-.I]), by = groupId]
},
christian_lookup = {
nrows <- dt[, .N]
dt_match <- dt[, .(nrows_other = nrows- .N), by = .(groupId_match = groupId)]
dt_match[, odd := nrows_other %% 2]
dt_match[, idx1 := ceiling(nrows_other/2)]
dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]
setkey(dt, value)
dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups,
on = c(groupId = "groupId_match")]
},
talat_unlist = {
d2 = dt[, .(value = list(value)), keyby = .(groupId)]
setkey(dt, groupId)
dt[, medianOfAllTheOtherGroups :=
fmedian(d2[-.GRP, unlist(value, use.names = FALSE, recursive = FALSE)]),
by = .(groupId)]
})
Run Code Online (Sandbox Code Playgroud)
精确结果的方法:中位数是排序向量的“中间”值。(或者偶数长度向量的两个中间值的平均值)如果我们知道其他向量的排序长度,我们可以直接查找相应的向量元素索引来获取中位数,从而避免实际计算中位数 n*groupId 次:
library(data.table)
set.seed(1)
numberOfGroups <- 5000
numberOfValuesPerGroup <- 50
dt <-
data.table(
groupId = as.character(rep(1:numberOfGroups, each = numberOfValuesPerGroup)),
value = round(runif(n = numberOfGroups * numberOfValuesPerGroup), 4)
)
# group count match table + idx position for median of others
nrows <- dt[, .N]
dt_match <- dt[, .(nrows_other = nrows- .N), by = .(groupId_match = groupId)]
dt_match[, odd := nrows_other %% 2]
dt_match[, idx1 := ceiling(nrows_other/2)]
dt_match[, idx2 := ifelse(odd, idx1, idx1+1)]
setkey(dt, value)
dt_match[, medianOfAllTheOtherGroups := dt[groupId != groupId_match][c(idx1, idx2), sum(value)/2], by = groupId_match]
dt[dt_match, medianOfAllTheOtherGroups := i.medianOfAllTheOtherGroups,
on = c(groupId = "groupId_match")]
Run Code Online (Sandbox Code Playgroud)
我猜想,可能还有更多类似 data.table 的方法可以进一步提高性能。
numberOfGroups = 5000 且 numberOfValuesPerGroup = 50 的内存/运行时:20GB,27000ms