我有一个很长的参数向量(大约4 ^ 10个元素)和一个索引向量.我的目标是将索引向量中索引的所有参数值加在一起.
例如,如果我有para = [1,2,3,4,5,5,5]和indices = [3,3,1,6]那么我想找到第三个值的累积和(3 )两次,第一个值(1)和第六个(5),得到12.另外还有根据它们的位置扭曲参数值的选项.
我正在尝试加速R实现,因为我称之为数百万次.
我当前的代码总是返回NA,我无法看到它出错的地方
这是Rcpp函数:
double dot_prod_c(NumericVector indices, NumericVector paras,
NumericVector warp = NA_REAL) {
int len = indices.size();
LogicalVector indices_ok;
for (int i = 0; i < len; i++){
indices_ok.push_back(R_IsNA(indices[i]));
}
if(is_true(any(indices_ok))){
return NA_REAL;
}
double counter = 0;
if(NumericVector::is_na(warp[1])){
for (int i = 0; i < len; i++){
counter += paras[indices[i]];
}
} else {
for (int i = 0; i < len; i++){
counter += paras[indices[i]] * warp[i];
}
}
return counter;
}
Run Code Online (Sandbox Code Playgroud)
这是工作R版本:
dot_prod <- function(indices, paras, warp = NA){
if(is.na(warp[1])){
return(sum(sapply(indices, function(ind) paras[ind + 1])))
} else {
return(sum(sapply(1:length(indices), function(i){
ind <- indices[i]
paras[ind + 1] * warp[i]
})))
}
}
Run Code Online (Sandbox Code Playgroud)
以下是使用microbenchmark软件包进行测试和基准测试的一些代码:
# testing
library(Rcpp)
library(microbenchmark)
parameters <- list()
indices <- list()
indices_trad <- list()
set.seed(2)
for (i in 4:12){
size <- 4^i
window_size <- 100
parameters[[i-3]] <- runif(size)
indices[[i-3]] <- floor(runif(window_size)*size)
temp <- rep(0, size)
for (j in 1:window_size){
temp[indices[[i-3]][j] + 1] <- temp[indices[[i-3]][j] + 1] + 1
}
indices_trad[[i-3]] <- temp
}
microbenchmark(
x <- sapply(1:9, function(i) dot_prod(indices[[i]], parameters[[i]])),
x_c <- sapply(1:9, function(i) dot_prod_c(indices[[i]], parameters[[i]])),
x_base <- sapply(1:9, function(i) indices_trad[[i]] %*% parameters[[i]])
)
all.equal(x, x_base) # is true, does work
all.equal(x_c, x_base) # not true - C++ version returns only NAs
Run Code Online (Sandbox Code Playgroud)
我试图通过你的代码来解释你的总体目标时遇到了一些麻烦,所以我只想解释一下这个问题
例如,如果我有para = [1,2,3,4,5,5,5]和indices = [3,3,1,6]那么我想找到第三个值的累积和(3 )两次,第一个值(1)和第六个(5),得到12.另外还有根据它们的位置扭曲参数值的选项.
因为我最清楚.
您的C++代码存在一些问题.首先,不要这样做NumericVector warp = NA_REAL- 使用Rcpp::Nullable<>模板(如下所示).这将解决一些问题:
Nullable类,它几乎就是它听起来的样子 - 一个可能是也可能不是null的对象.NumericVector warp = NA_REAL.坦率地说,我很惊讶编译器接受了这一点.if(NumericVector::is_na(warp[1])){.这有不明确的行为写在它上面.这是一个修订版本,取消了您对上述问题的引用说明:
#include <Rcpp.h>
typedef Rcpp::Nullable<Rcpp::NumericVector> nullable_t;
// [[Rcpp::export]]
double DotProd(Rcpp::NumericVector indices, Rcpp::NumericVector params, nullable_t warp_ = R_NilValue) {
R_xlen_t i = 0, n = indices.size();
double result = 0.0;
if (warp_.isNull()) {
for ( ; i < n; i++) {
result += params[indices[i]];
}
} else {
Rcpp::NumericVector warp(warp_);
for ( ; i < n; i++) {
result += params[indices[i]] * warp[i];
}
}
return result;
}
Run Code Online (Sandbox Code Playgroud)
您有一些精心设计的代码来生成示例数据.我没有花时间来完成这个,因为没有必要,基准测试也没有.您自己说过C++版本没有产生正确的结果.您的首要任务应该是让您的代码处理简单数据.然后给它提供一些更复杂的数据.然后基准.上面的修订版本适用于简单数据:
args <- list(
indices = c(3, 3, 1, 6),
params = c(1, 2, 3, 4, 5, 5, 5),
warp = c(.25, .75, 1.25, 1.75)
)
all.equal(
DotProd(args[[1]], args[[2]]),
dot_prod(args[[1]], args[[2]]))
#[1] TRUE
all.equal(
DotProd(args[[1]], args[[2]], args[[3]]),
dot_prod(args[[1]], args[[2]], args[[3]]))
#[1] TRUE
Run Code Online (Sandbox Code Playgroud)
它也比此样本数据上的R版本更快.我没有理由相信它不适用于更大,更复杂的数据 - *apply函数没有什么神奇或特别的效率; 它们只是更惯用/可读的R.
microbenchmark::microbenchmark(
"Rcpp" = DotProd(args[[1]], args[[2]]),
"R" = dot_prod(args[[1]], args[[2]]))
#Unit: microseconds
#expr min lq mean median uq max neval
#Rcpp 2.463 2.8815 3.52907 3.3265 3.8445 18.823 100
#R 18.869 20.0285 21.60490 20.4400 21.0745 66.531 100
#
microbenchmark::microbenchmark(
"Rcpp" = DotProd(args[[1]], args[[2]], args[[3]]),
"R" = dot_prod(args[[1]], args[[2]], args[[3]]))
#Unit: microseconds
#expr min lq mean median uq max neval
#Rcpp 2.680 3.0430 3.84796 3.701 4.1360 12.304 100
#R 21.587 22.6855 23.79194 23.342 23.8565 68.473 100
Run Code Online (Sandbox Code Playgroud)
我省略了NA上面例子中的检查,但是通过使用一点Rcpp糖也可以修改为更惯用的东西.以前,你这样做:
LogicalVector indices_ok;
for (int i = 0; i < len; i++){
indices_ok.push_back(R_IsNA(indices[i]));
}
if(is_true(any(indices_ok))){
return NA_REAL;
}
Run Code Online (Sandbox Code Playgroud)
它有点咄咄逼人 - 你正在测试一个完整的值向量(带R_IsNA),然后应用is_true(any(indices_ok))- 当你可能过早地破坏并返回NA_REAL第一个R_IsNA(indices[i])导致的实例时true.此外,使用push_back会减慢你的功能 - 你最好初始化indices_ok到已知的大小并通过循环中的索引访问来填充它.不过,这是压缩操作的一种方法:
if (Rcpp::na_omit(indices).size() != indices.size()) return NA_REAL;
Run Code Online (Sandbox Code Playgroud)
为了完整起见,这里有一个完全糖化的版本,可以让你完全避免循环:
#include <Rcpp.h>
typedef Rcpp::Nullable<Rcpp::NumericVector> nullable_t;
// [[Rcpp::export]]
double DotProd3(Rcpp::NumericVector indices, Rcpp::NumericVector params, nullable_t warp_ = R_NilValue) {
if (Rcpp::na_omit(indices).size() != indices.size()) return NA_REAL;
if (warp_.isNull()) {
Rcpp::NumericVector tmp = params[indices];
return Rcpp::sum(tmp);
} else {
Rcpp::NumericVector warp(warp_), tmp = params[indices];
return Rcpp::sum(tmp * warp);
}
}
/*** R
all.equal(
DotProd3(args[[1]], args[[2]]),
dot_prod(args[[1]], args[[2]]))
#[1] TRUE
all.equal(
DotProd3(args[[1]], args[[2]], args[[3]]),
dot_prod(args[[1]], args[[2]], args[[3]]))
#[1] TRUE
*/
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1703 次 |
| 最近记录: |