选择作为给定向量的排列的矩阵行

Ric*_*rdo 10 r permutation matrix subset

我有一个矩阵X

     one two three four
 [1,]  1   3   2   4
 [2,]  2   0   1   5
 [3,]  3   2   1   4
 [4,]  4   9   11  19
 [5,]  4   3   2   1
Run Code Online (Sandbox Code Playgroud)

我想要一个新的矩阵Y,其中仅包含“1”、“2”、“3”、“4”的排列行。那是:

     one two three four
 [1,]  1   3   2   4
 [3,]  3   2   1   4
 [5,]  4   3   2   1
Run Code Online (Sandbox Code Playgroud)

我应该使用什么函数或命令?

Sté*_*ent 9

mat <- rbind(
    c(1, 3, 2, 4),
    c(2, 0, 1, 5),
    c(3, 2, 1, 4)
)

ok <- apply(mat, 1L, function(x) setequal(x, c(1, 2, 3, 4)))
mat[ok, ]
Run Code Online (Sandbox Code Playgroud)


李哲源*_*李哲源 7

您的示例矩阵和目标向量:

X <- structure(c(1, 2, 3, 4, 4, 3, 0, 2, 9, 3, 2, 1, 1, 11, 2, 4, 5, 4, 19, 1),
               dim = 5:4)
v <- 1:4
Run Code Online (Sandbox Code Playgroud)

但让我们构建一个更具挑战性的(感谢用户h​​arre):

X <- rbind(X, 1, c(1, 2, 1, 2))
Run Code Online (Sandbox Code Playgroud)

完全矢量化的方法(使用matrixStats包)

rk <- matrix(match(X, v, nomatch = 0L), nrow(X), ncol(X))
ct <- matrixStats::rowTabulates(rk, values = 1:length(v))
zo <- matrixStats::rowCounts(ct, value = 0L)

## all rows that are permutations of 'v'
X[zo == 0L, ]
#     [,1] [,2] [,3] [,4]
#[1,]    1    3    2    4
#[2,]    3    2    1    4
#[3,]    4    3    2    1

## remove rows that are permutations of 'v'
X[zo > 0L, ]
Run Code Online (Sandbox Code Playgroud)

另一种完全矢量化的方法(基础 R)

这是一个数学解。对于非线性和非对称权重函数w(x),以下加权和:

1 xw(1) + 2 xw(2) + 3 xw(3) + 4 xw(4)

是唯一的分数或标识符并且对于排列是不变的。例如,以下给出相同的值:

2 xw(2) + 1 xw(1) + 3 xw(3) + 4 xw(4)

但其他任何东西都会给出不同的值,例如:

1 xw(1) + 3 xw(1) + 3 xw(3) + 4 xw(4)

0 xw(0) + 3 xw(1) + 0 xw(0) + 4 xw(4)

这是使用余弦权重的实现。即使 和 是浮点数或字符,它也可以X工作v

## method = "tab" for tabulation method
## method = "cos" for cosine weights method
FindPerm <- function (X, v, method) {
  ## number of elements
  n <- length(v)
  if (ncol(X) != n) stop("inconformable 'X' and 'v'!")
  if (anyDuplicated(v)) stop("'v' should not contain duplicated values!")
  ## convert everything to integers 0, 1, 2, ..., n
  Xi <- matrix(match(X, v, nomatch = 0L), nrow(X), ncol(X))
  vi <- 1:n
  ## branches
  if (method == "tab") {
    ## row-wise tabulating
    rtab <- matrixStats::rowTabulates(Xi, values = vi)
    ## the i-th value is TRUE if X[i, ] is a permutation of v
    matrixStats::rowCounts(rtab, value = 0L) == 0L
  } else if (method == "cos") {
    ## evaluate cosine weights for Xi and vi
    w <- pi / (n + 1)
    cos.Xi <- cos(w * Xi)
    cos.vi <- cos(w * vi)
    ## weighted sum for Xi
    wtsum.Xi <- rowSums(Xi * cos.Xi)
    ## weighted sum for vi
    wtsum.vi <- sum(vi * cos.vi)
    ## the i-th value is TRUE if X[i, ] is a permutation of v
    wtsum.Xi == wtsum.vi
  } else {
    stop("unknown method!")
  }
}
Run Code Online (Sandbox Code Playgroud)
X[FindPerm(X, v, "cos"), ]
#     [,1] [,2] [,3] [,4]
#[1,]    1    3    2    4
#[2,]    3    2    1    4
#[3,]    4    3    2    1
Run Code Online (Sandbox Code Playgroud)

基准

性能取决于 中值的数量v。制表方法会随着变长而变慢v

## a benchmark function, relying on package "microbenchmark"
## nr: number of matrix rows
## nc: number of elements in 'v'
bm <- function (nr, nc) {
  X <- matrix(sample.int(nc + 1L, nr * nc, replace = TRUE), nr)
  v <- 1:nc
  microbenchmark::microbenchmark("tab" = FindPerm(X, v, "tab"),
                                 "cos" = FindPerm(X, v, "cos"),
                                 check = "identical")
}

bm(2e+4, 4)
#Unit: milliseconds
# expr      min       lq     mean   median       uq      max
#  tab 4.302674 4.324236 4.536260 4.336955 4.359814 7.039699 
#  cos 4.846893 4.872361 5.163209 4.882942 4.901288 7.837580

bm(2e+4, 20)
#Unit: milliseconds
# expr      min       lq     mean   median       uq       max
#  tab 30.63438 30.70217 32.73508 30.77588 33.08046 135.64322
#  cos 21.16669 21.26161 22.28298 21.37563 23.60574  26.31775 
Run Code Online (Sandbox Code Playgroud)


jbl*_*d94 5

由于人们对这个问题非常感兴趣,因此进行更新,这里有一种使用索引的方法来提高李哲元对我原始答案的出色概括的速度。

这个想法是在一个维数组上索引length(v)small v,或者v*sin(w*v)使用结果索引match而不是计算当islargeX*sin(W*X)时:v

library(RcppAlgos)

# simplified version of Zheyuan Li's function
f1 <- function(X, v) {
  n <- length(v)
  Xi <- matrix(match(X, v, nomatch = 0L), nrow(X), ncol(X))
  vi <- 1:n
  w <- pi/(n + 1)
  rowSums(Xi*sin(Xi*w)) == sum(vi*sin(vi*w))
}

f2 <- function(X, v) {
  n <- length(v)
  
  if (n < 6) {
    # index an n-dimensional array
    m <- array(FALSE, rep(n + 1L, n))
    m[permuteGeneral(n)] <- TRUE
    X[] <- match(X, v, nomatch = length(v) + 1L)
    m[X]
  } else {
    nn <- 1:n
    u <- c(nn*sin(pi*nn/(n + 1L)), 0)
    X[] <- u[match(X, v, nomatch = n + 1L)]
    rowSums(X) == sum(u)
  }
}

set.seed(123)
# using Zheyuan Li's test dataset
nr <- 2000; nc <- 4
X <- matrix(sample.int(nc + 1L, nr * nc, replace = TRUE), nr)
v <- 1:nc

microbenchmark::microbenchmark(f1 = f1(X, v),
                               f2 = f2(X, v),
                               check = "identical")
#> Unit: microseconds
#>  expr   min     lq    mean median     uq    max neval
#>    f1 344.4 367.25 438.932 374.05 386.75 5960.6   100
#>    f2  81.9  85.00 163.332  88.90  98.50 6924.4   100

# Zheyuan Li's larger test dataset
set.seed(123)
nr <- 2000; nc <- 20
X <- matrix(sample.int(nc + 1L, nr * nc, replace = TRUE), nr)
v <- 1:nc

microbenchmark::microbenchmark(f1 = f1(X, v),
                               f2 = f2(X, v),
                               check = "identical")
#> Unit: microseconds
#>  expr    min      lq     mean  median     uq    max neval
#>    f1 1569.2 1575.45 1653.510 1601.30 1683.6 3972.6   100
#>    f2  355.2  359.90  431.705  366.85  408.6 2253.8   100
Run Code Online (Sandbox Code Playgroud)

原始答案已编辑使用X + exp(1/X)(请参阅评论)。

这应该适用于正整数:

Y <- X[rowSums(X + exp(1/X)) == sum(1:4 + exp(1/(1:4))),]
Run Code Online (Sandbox Code Playgroud)

针对apply解决方案进行基准测试:

f1 <- function(x) x[apply(x, 1L, function(x) setequal(x, 1:4)),]
f2 <- function(x) x[rowSums(x + exp(1/x)) == sum(1:4 + exp(1/(1:4))),]

X <- matrix(sample(10, 4e5, TRUE), 1e5)
microbenchmark::microbenchmark(f1 = f1(X),
                               f2 = f2(X),
                               times = 10,
                               check = "equal")
#> Unit: milliseconds
#>  expr      min       lq      mean    median       uq      max neval
#>    f1 448.2680 450.8778 468.55179 461.62620 472.0022 542.0455    10
#>    f2  28.5362  28.6889  31.50941  29.44845  30.2693  50.4402    10
Run Code Online (Sandbox Code Playgroud)


Tho*_*ing 3

我们可以试试这个

> mat[colSums(mapply(`%in%`, list(1:4), asplit(mat, 1))) == ncol(mat), ]
     [,1] [,2] [,3] [,4]
[1,]    1    3    2    4
[2,]    3    2    1    4
[3,]    4    3    2    1
Run Code Online (Sandbox Code Playgroud)