R - 获取向量中最大n个元素的索引的最快方法

12 sorting r vector

假设我有一个包含 100 万个元素的巨大向量 `x`,我想找到其中最大的 30 个元素的索引。我并不在意这 30 个索引之间是否有序,只要它们对应的是整个向量中最大的 30 个值即可。使用 `order(x, decreasing = TRUE)[1:30]` 似乎代价相当高,因为它必须对整个向量进行排序。我考虑过使用 `sort` 的 `partial` 选项,但 `sort` 返回的是值而不是索引,而且指定 `partial` 时不支持 `index.return` 选项。有没有一种高效的方法,无需对整个向量排序就能找到这些索引?

sgi*_*ibb 12

我想补充一种混合方法:结合 `sort` 的 `partial` 参数和 `which`:

# Indices of the n largest elements of x, without fully sorting x.
#
# A partial sort (sort(..., partial = p)) guarantees that position p holds
# the p-th smallest value; every element strictly above it is in the top n.
# If that threshold value is tied, the lowest-indexed tied elements fill the
# remaining slots so that exactly n indices are returned.
#
# x: numeric vector; NA values are ignored.
# n: number of indices wanted (default 30); n <= 0 gives integer(0).
# Returns an integer vector of min(n, number of non-NA values) indices in
#   increasing index order (not ordered by value).
whichpart <- function(x, n = 30) {
  if (n <= 0) {
    return(integer(0))
  }
  # sort() drops NAs, so only the non-NA values are rankable.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n >= nx) {
    return(which(!is.na(x)))
  }
  p <- nx - n
  xp <- sort(x, partial = p)[p]
  above <- which(x > xp)
  shortfall <- n - length(above)
  if (shortfall == 0) {
    return(above)
  }
  # Ties at the threshold: take just enough tied elements to reach n.
  tied <- which(x == xp)
  sort(c(above, tied[seq_len(shortfall)]))
}
Run Code Online (Sandbox Code Playgroud)

一些基准测试:

library("microbenchmark")
library("data.table")
library("compiler")

# Seeded benchmark inputs: one million standard-normal doubles in x, and a
# random permutation of the integers 1..1e6 in y.
set.seed(seed = 123)
x <- rnorm(n = 1e6, mean = 0, sd = 1)
y <- sample.int(n = 1e6, size = 1e6)


# Indices of the n largest elements of x, without fully sorting x.
#
# A partial sort (sort(..., partial = p)) guarantees that position p holds
# the p-th smallest value; every element strictly above it is in the top n.
# If that threshold value is tied, the lowest-indexed tied elements fill the
# remaining slots so that exactly n indices are returned.
#
# x: numeric vector; NA values are ignored.
# n: number of indices wanted (default 30); n <= 0 gives integer(0).
# Returns an integer vector of min(n, number of non-NA values) indices in
#   increasing index order (not ordered by value).
whichpart <- function(x, n = 30) {
  if (n <= 0) {
    return(integer(0))
  }
  # sort() drops NAs, so only the non-NA values are rankable.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n >= nx) {
    return(which(!is.na(x)))
  }
  p <- nx - n
  xp <- sort(x, partial = p)[p]
  above <- which(x > xp)
  shortfall <- n - length(above)
  if (shortfall == 0) {
    return(above)
  }
  # Ties at the threshold: take just enough tied elements to reach n.
  tied <- which(x == xp)
  sort(c(above, tied[seq_len(shortfall)]))
}

cpwhichpart <- cmpfun(whichpart)

# using quicksort
# Indices of the n largest elements of x via a full decreasing quicksort.
#
# x: numeric vector.
# n: number of indices wanted (default 30).
# Returns an integer vector of at most n indices, largest value first.
# NOTE(review): sort() drops NAs before sorting; confirm `ix` still refers
#   to positions in the original x if x may contain NA.
quicksort <- function(x, n = 30) {
  ix <- sort(x, method = "quick", decreasing = TRUE, index.return = TRUE)$ix
  # seq_len() instead of 1:n: 1:0 is c(1, 0), which would return one index
  # for n = 0, and indexing past the end of ix would pad with NA.
  ix[seq_len(max(0, min(n, length(ix))))]
}

cpquicksort <- cmpfun(quicksort)

# @Mariam
# Indices of every element of x that is >= the n-th largest value, found
# with a full decreasing sort. All elements tied with the threshold are
# included, so more than n indices can be returned.
#
# x: numeric vector; NA values are ignored.
# n: number of top values wanted (default 30); n <= 0 gives integer(0).
# Returns an integer vector of indices in increasing index order.
whichsort <- function(x, n = 30) {
  if (n <= 0) {
    return(integer(0))
  }
  sorted <- sort(x, decreasing = TRUE)
  # Use n (the original hard-coded 30) and clamp it so that asking for more
  # values than exist selects everything instead of comparing against NA.
  threshold <- sorted[min(n, length(sorted))]
  which(x >= threshold, arr.ind = TRUE)
}

cpwhichsort <- cmpfun(whichsort)

# @Ferdinand.kraft
# Indices of the n largest elements of x, found by repeatedly calling
# which.max() and masking the winner; cost grows with n * length(x).
#
# x: numeric vector; NA values are never selected.
# n: number of indices wanted (default 30); n <= 0 gives numeric(0).
# Returns a numeric (double) vector of min(n, number of non-NA values)
#   indices, largest value first; which.max() picks the first index on ties.
top <- function(x, n = 30) {
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  # Preallocate instead of growing; seq_along() avoids the 1:0 trap.
  result <- numeric(max(0, min(n, nx)))
  for (i in seq_along(result)) {
    j <- which.max(x)
    result[i] <- j
    # Mask with NA rather than -Inf: which.max() skips NAs, so a taken
    # element can never be picked again (even when x itself holds -Inf),
    # and an integer x is not coerced to double on the first pass.
    x[j] <- NA
  }
  result
}

cptop <- cmpfun(top)

# @Tony Breyal
# Indices of the n largest elements of x via a keyed data.table: setkey()
# orders the table ascending by x, so the top n sit in the last n rows.
#
# x: numeric vector.
# n: number of indices wanted (default 30); n <= 0 gives integer(0).
# Returns an integer vector of at most n indices, largest value first.
dtable <- function(x, n = 30) {
  # seq_along(), not seq.int(x): for a length-1 x, seq.int(x) means 1:x.
  dt <- data.table(x = x, x.index = seq_along(x))
  setkey(dt, "x")
  # Read from the end of the ascending key (the start holds the smallest
  # values) and reverse so the largest value comes first.
  rev(tail(dt$x.index, max(0, n)))
}

cpdtable <- cmpfun(dtable)

# @Roland
# Single pass over x that keeps a buffer of the n largest VALUES seen so far
# (y[1] is always the smallest of them). Unlike the other candidates here,
# it returns values, not indices.
#
# x: numeric vector; NA values are skipped.
# n: buffer size (default 30); n <= 0 gives numeric(0).
# Returns a double vector of length n in increasing order; if x has fewer
#   than n values above -Inf, the leading entries stay -Inf.
roland <- cmpfun(function(x, n = 30) {
  if (n <= 0) {
    return(numeric(0))
  }
  y <- rep(-Inf, n)
  for (i in seq_along(x)) {
    xi <- x[i]
    # !is.na() guard: an NA comparison inside if () is an error.
    if (!is.na(xi) && xi > y[1]) {
      y[1] <- xi
      y <- y[order(y)]
    }
  }
  y
})

## rnorm
# Benchmark every candidate, plain and byte-compiled (cp*), on the rnorm
# doubles in x, 10 runs each. The expr column in the output is the call text,
# so the calls are spelled out one by one.
# roland is already wrapped in cmpfun() where it is defined, so it has no
# separate cp* twin; it also returns the top values rather than indices.
# NOTE(review): the timings below were recorded by the original author on
# their machine; re-run to compare after changing any implementation.
microbenchmark(whichpart(x), cpwhichpart(x),
               quicksort(x), cpquicksort(x),
               whichsort(x), cpwhichsort(x),
               top(x), cptop(x),
               dtable(x), cpdtable(x),
               roland(x), times=10)

# Unit: milliseconds
#            expr        min         lq     median         uq        max neval
#    whichpart(x)   45.63544   46.05638   47.09077   49.68452   51.42065    10
#  cpwhichpart(x)   45.65996   45.77212   47.02808   48.07482   82.20458    10
#    quicksort(x)  100.90936  103.00783  105.17506  109.31784  139.83518    10
#  cpquicksort(x)  100.53958  102.78017  107.64470  138.96630  142.52882    10
#    whichsort(x)  148.86010  151.04350  155.80871  159.47063  184.56697    10
#  cpwhichsort(x)  149.05578  150.21183  151.36918  166.58342  173.87567    10
#          top(x)  146.10757  182.42089  184.53050  191.37293  193.62272    10
#        cptop(x)  155.14354  179.14847  184.52323  196.80644  220.21222    10
#       dtable(x) 1041.32457 1042.54904 1049.26096 1065.40606 1080.89969    10
#     cpdtable(x) 1042.08247 1043.54915 1051.76366 1084.14360 1310.26485    10
#       roland(x)  251.42885  261.47608  273.20838  295.09733  323.96257    10

## integer
# Same benchmark on y, the random permutation of the integers 1..1e6 drawn
# with sample.int(), 10 runs each.
# NOTE(review): timings recorded by the original author; re-run after
# changing any implementation.
microbenchmark(whichpart(y), cpwhichpart(y),
               quicksort(y), cpquicksort(y),
               whichsort(y), cpwhichsort(y),
               top(y), cptop(y),
               dtable(y), cpdtable(y),
               roland(y), times=10)

# Unit: milliseconds
#            expr       min        lq    median        uq       max neval
#    whichpart(y)  11.60703  11.76857  12.03704  12.52871  47.88526    10
#  cpwhichpart(y)  11.62885  11.75006  12.53724  13.88563  46.93677    10
#    quicksort(y)  88.14924  89.47630  92.42414 103.53439 137.44335    10
#  cpquicksort(y)  88.11544  89.15334  92.63420  94.42244 133.78006    10
#    whichsort(y) 122.34675 123.13634 124.91990 127.79134 131.43400    10
#  cpwhichsort(y) 121.85618 122.91653 125.45211 127.14112 158.61535    10
#          top(y) 163.06669 181.19004 211.11557 224.19237 239.63139    10
#        cptop(y) 163.37903 173.55113 209.46770 218.59685 226.81545    10
#       dtable(y) 499.50807 505.45513 514.55338 537.84129 604.86454    10
#     cpdtable(y) 491.70016 498.62664 525.05342 527.14666 580.19429    10
#       roland(y) 235.44664 237.52200 242.87925 268.34080 287.71196    10


# Sanity check: the partial-sort answer picks the same indices as the full
# quicksort once the latter is put into increasing index order.
identical(whichpart(x), sort(quicksort(x)))
# [1] TRUE
Run Code Online (Sandbox Code Playgroud)

编辑:测试 @flodel 的建议:

# @flodel
# Indices of every element of x that is >= the n-th largest value, using a
# partial sort of -x so that the threshold lands at position n. All
# elements tied with the threshold are included, so more than n indices can
# be returned.
#
# x: numeric vector; NA values are ignored.
# n: number of top values wanted (default 30); n <= 0 gives integer(0).
# Returns an integer vector of indices in increasing index order.
whichpartrev <- function(x, n = 30) {
  if (n <= 0) {
    return(integer(0))
  }
  # sort() drops NAs, so partial = n is only valid up to the non-NA count;
  # at or beyond that count every non-NA element qualifies anyway.
  nx <- if (anyNA(x)) sum(!is.na(x)) else length(x)
  if (n >= nx) {
    return(which(!is.na(x)))
  }
  which(x >= -sort(-x, partial = n)[n])
}

# Head-to-head: whichpart against @flodel's partial sort of -x, 100 runs
# each, first on the rnorm doubles (x), then on the integer permutation (y).
# NOTE(review): in these recorded timings whichpartrev had the lower median
# for doubles but the higher median for integers; re-run before relying on
# either result.
microbenchmark(whichpart(x), whichpartrev(x), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(x) 45.44940 46.15011 46.51321 48.67986 80.63286   100
#  whichpartrev(x) 28.84482 31.30661 32.87695 62.37843 67.84757   100

microbenchmark(whichpart(y), whichpartrev(y), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(y) 11.56135 12.26539 13.05729 13.75199 43.78484   100
#  whichpartrev(y) 16.00612 16.73690 17.71687 19.04153 49.02842   100
Run Code Online (Sandbox Code Playgroud)