我需要访问包含MatrixBase Eigen矩阵数据的数组.
Eigen库有data()方法,它返回一个指向数组的指针,但只能从Matrix 类型访问它.该MatrixBase没有类似的方法,即使MatrixBase类应该充当模板和实际类型应该只是一个矩阵.如果我尝试访问MatrixBase.data(),我得到一个编译时错误:
template <typename ScalarA, typename Index, typename DerivedB, typename DerivedC>
void uscgemv(float alpha,
const USCMatrix<ScalarA,Index> &a,
const MatrixBase<DerivedB> &b,
const MatrixBase<DerivedC> &c_const)
{
//...some code
float * bMat = b.data();
///more code
}
Run Code Online (Sandbox Code Playgroud)
此代码生成以下编译时错误.
error: ‘const class Eigen::MatrixBase<Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<float>, Eigen::Matrix<float, -1, 1> > >’ has no member named ‘data’
float * bMat = b.data();
Run Code Online (Sandbox Code Playgroud)
所以我不得不诉诸噱头......
float * bMat;
int bRows = b.rows();
int bCols = b.cols();
mallocPinnedMemory(&bMat, bRows*bCols*sizeof(float));
Eigen::Map<Matrix<float, Dynamic, Dynamic> > bmat_temp(bMat, bRows, bCols);
bmat_temp = b; //THis is SLOW, we should avoid it.
Run Code Online (Sandbox Code Playgroud)
然后我可以访问bMat数组......
那些副本来回是gpu矩阵乘法的最大成本,因为我基本上必须制作一个额外的副本,甚至在应对设备之前......
我不能使用Eigen-magma,因为这是稀疏矩阵 - 奇怪的格式到密集矩阵(或有时矢量)乘法,所以我不能使用那里的任何自动gpu函数.另外,我宁愿不将矩阵声明为其他东西,因为这需要在整个程序中更改大量代码行(我没有写过).
编辑:提出了静态演员解决方案:
float * bMat = (static_cast<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> >(b)).data();
Run Code Online (Sandbox Code Playgroud)
但是,当我第一次尝试访问数组bMat的元素时,我得到段错误.
编辑2:我正在寻找一种零拷贝方式来访问底层数组.我只需要能够读取b,但我还需要能够写入c.目前c与以下宏不对称:
#define UNCONST(t,c,uc) Eigen::MatrixBase<t> &uc = const_cast<Eigen::MatrixBase<t>&>(c);
Run Code Online (Sandbox Code Playgroud)
编辑3:交叉发布到Eigen论坛后,似乎我做得不比建议的答案好.
MatrixBase是任何密集表达式的基类.它不一定对应于具有存储的对象.例如,可以是A+B具有常量值的向量的抽象表示,或者在您的情况下是抽象表示.您可以使uscgemv仅接受使用Ref<>该类具有适当存储的表达式,例如:
template <typename ScalarA, typename Index>
void uscgemv(float alpha,
const USCMatrix<ScalarA,Index> &a,
Ref<const VectorXf> b,
Ref<VectorXf> c);
Run Code Online (Sandbox Code Playgroud)
如果第三个参数与a的存储不匹配,VectorXf那么它将为您进行评估.然后你可以安全地打电话b.data().要保持标量类型的b泛型,您仍然可以MatrixBase<DerivedB>&将其声明为然后将其复制到Ref<const Matrix<typename DerivedB::Scalar, DerivedB::RowsAtCompileTime, DerivedB::ColsAtCompileTime> >:
typedef Ref<const Matrix<typename DerivedB::Scalar, DerivedB::RowsAtCompileTime, DerivedB::ColsAtCompileTime> > RefB;
RefB actual_b(b);
actual_b.data();
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
6725 次 |
| 最近记录: |