如何有效地检查矩阵是否为二进制形式(例如,所有1或0)?

Csi*_*der 8 r matrix

我有一个函数,它采用mxn大小(可能)二进制矩阵作为输入,如果矩阵包含的数字不是0或1,或者是NA,我想返回错误处理.我怎样才能有效地检查这个?

例如,通过为10 x 10生成一些数据:

> n=10;m=10
> mat = round(matrix(runif(m*n), m, n))
> mat
        [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
 [1,]    0    1    0    1    1    0    1    0    1     0
 [2,]    0    0    0    0    0    0    0    0    0     1
 [3,]    1    1    0    1    1    0    0    1    1     0
 [4,]    1    1    1    1    0    1    0    0    1     1
 [5,]    1    1    1    0    0    1    1    1    0     1
 [6,]    1    0    1    0    0    0    0    1    0     0
 [7,]    0    0    0    1    0    1    1    1    1     0
 [8,]    0    0    0    1    0    1    1    1    1     1
 [9,]    0    0    1    1    0    1    1    1    1     1
[10,]    1    0    1    1    0    0    0    0    1     1
Run Code Online (Sandbox Code Playgroud)

应该始终返回矩阵是二进制的,但是可以通过以下方式之一进行更改:

> mat[1,1]=NA
> mat[1,1]=2
Run Code Online (Sandbox Code Playgroud)

应该返回矩阵不是二进制的.

目前,我一直在使用我的功能:

for(i in 1:nrow(mat))
{
    for(j in 1:ncol(mat))
    {
      if(is.na(mat[i,j])|(!(mat[i,j] == 1 | mat[i,j] == 0)))
      {
        stop("Data must be only 0s, 1s")
      }
    }
}
Run Code Online (Sandbox Code Playgroud)

但是单独检查大型矩阵的每个值似乎非常缓慢且效率低下.有没有一种聪明,简单的方法可以做到这一点我错过了?

谢谢

Jam*_*ble 5

以下是一些选项的时间安排(包括其他答案中建议的选项):

n=5000;m=5000
mat = round(matrix(runif(m*n), m, n))
> system.time(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
   user  system elapsed 
   0.30    0.02    0.31 
> system.time(stopifnot(all(mat %in% c(0,1))))
   user  system elapsed 
   0.58    0.06    0.63 
> system.time(stopifnot(all(mat==0 | mat==1)))
   user  system elapsed 
   0.77    0.03    0.80 
Run Code Online (Sandbox Code Playgroud)

考虑到它是一个5000乘5000的矩阵,它们都非常快!三者中最快的似乎是:

stopifnot(sum(mat==0) + sum(mat==1) == length(mat))
Run Code Online (Sandbox Code Playgroud)


sgi*_*ibb 5

我喜欢添加一个稍微修改过的版本,sum该版本比 @JamesTrimble 的版本更快。我希望我所有的假设都是正确的:

is.binary.sum2 <- function(x) {
  identical(sum(abs(x)) - sum(x == 1), 0)
}
Run Code Online (Sandbox Code Playgroud)

这是基准:

library(rbenchmark)

n=5000
m=5000
mat = round(matrix(runif(m*n), m, n))

is.binary.sum <- function(x) {
  sum(x == 0) + sum(x == 1) == length(x)
}

is.binary.sum2 <- function(x) {
  identical(sum(abs(x)) - sum(x == 1), 0)
}

is.binary.all <- function(x) {
  all(x == 0 | x == 1)
}

is.binary.in <- function(x) {
  all(x %in% c(0, 1))
}

benchmark(is.binary.sum(mat), is.binary.sum2(mat),
          is.binary.all(mat), is.binary.in(mat),
          order="relative", replications=10)
#                 test replications elapsed relative user.self sys.self user.child sys.child
#2 is.binary.sum2(mat)           10   4.635    1.000     3.872    0.744          0         0
#1  is.binary.sum(mat)           10   7.097    1.531     6.565    0.512          0         0
#4   is.binary.in(mat)           10  10.359    2.235     9.216    1.108          0         0
#3  is.binary.all(mat)           10  12.565    2.711    11.753    0.772          0         0
Run Code Online (Sandbox Code Playgroud)


Car*_*oft 5

我立刻想到了 identical(mat,matrix(as.numeric(as.logical(mat),nr=nrow(mat)) ) )

这使得NA作为NA,所以如果你想找出这样的存在,你只需要一个快速的any(is.na(mat))或类似的测试.

编辑:计时赛

fun2 <- function(x) {
      all(x %in% 0:1)
}
fun1 <-function(x) {identical(as.vector(x),as.numeric(as.logical(x)))}

mfoo<-matrix(sample(0:10,1e6,rep=TRUE),1e3)
 microbenchmark(fun1(mfoo),fun2(mfoo),is.binary.sum2(mfoo),times=10)
Unit: milliseconds
                 expr       min        lq    median        uq
           fun1(mfoo)  2.286941  2.809926  2.835584  2.865518
           fun2(mfoo) 20.369075 20.894627 21.100528 21.226464
 is.binary.sum2(mfoo) 11.394503 12.418238 12.431922 12.458436
       max neval
  2.920253    10
 21.407777    10
 28.316492    10
Run Code Online (Sandbox Code Playgroud)

反对的not...事情:我不得不投入一个,try以避免破坏测试.

notfun <- function(mat) try(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
 microbenchmark(fun1(mfoo),notfun(mfoo),is.binary.sum2(mfoo),times=10)
Error : sum(mat == 0) + sum(mat == 1) == length(mat) is not TRUE
##error repeated 10x for the 10 trials
Unit: milliseconds
                 expr       min        lq    median        uq
           fun1(mfoo)  4.870653  4.978414  5.057524  5.268344
         notfun(mfoo) 18.149273 18.685942 18.942518 19.241856
 is.binary.sum2(mfoo) 11.428713 12.145842 12.516165 12.605111
       max neval
  5.438111    10
 34.826230    10
 13.090465    10
Run Code Online (Sandbox Code Playgroud)

我赢了!:-)