我有两个整数数组,dmap 并且dflag 在相同长度的设备上,我用推力设备指针包裹它们,dmapt并且
dflagt
dmap数组中有一些值为-1的元素.我想从dflag数组中删除这些-1和相应的值.
我正在使用remove_if函数来执行此操作,但我无法弄清楚此调用的返回值是什么,或者我应该如何使用此返回值来获取.
(我想将这些简化的数组传递给将reduce_by_keydflagt用作键的函数.)
我正在使用以下调用进行缩减.请告诉我如何将返回值存储在变量中并使用它来处理各个数组dflag和dmap
thrust::remove_if(
thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)),
thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)),
minus_one_equality_test()
);
Run Code Online (Sandbox Code Playgroud)
将上面使用的谓词仿函数定义为
struct minus_one_equality_test
{
typedef typename thrust::tuple<int,int> Tuple;
__host__ __device__
bool operator()(const Tuple& a )
{
return thrust::get<0>(a) == (-1);
}
}
Run Code Online (Sandbox Code Playgroud)
返回值是一个zip_iterator,它标记了在remove_if调用期间,函子返回true的元组序列的新结尾.要访问底层数组的新结束迭代器,您需要从zip_iterator中检索元组迭代器; 然后,该元组的内容是用于构建zip_iterator的原始数组的新结束迭代器.在单词中比在代码中更复杂:
#include <thrust/tuple.h>
#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>
#include <thrust/remove.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/copy.h>
#include <iostream>
struct minus_one_equality_test
{
typedef thrust::tuple<int,int> Tuple;
__host__ __device__
bool operator()(const Tuple& a )
{
return thrust::get<0>(a) == (-1);
};
};
int main(void)
{
const int numindices = 10;
int mapt[numindices] = { 1, 2, -1, 4, 5, -1, 7, 8, -1, 10 };
int flagt[numindices] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
thrust::device_vector<int> vmapt(10);
thrust::device_vector<int> vflagt(10);
thrust::copy(mapt, mapt+numindices, vmapt.begin());
thrust::copy(flagt, flagt+numindices, vflagt.begin());
thrust::device_ptr<int> dmapt = vmapt.data();
thrust::device_ptr<int> dflagt = vflagt.data();
typedef thrust::device_vector< int >::iterator VIt;
typedef thrust::tuple< VIt, VIt > TupleIt;
typedef thrust::zip_iterator< TupleIt > ZipIt;
ZipIt Zend = thrust::remove_if(
thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)),
thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)),
minus_one_equality_test()
);
TupleIt Tend = Zend.get_iterator_tuple();
VIt vmapt_end = thrust::get<0>(Tend);
for(VIt x = vmapt.begin(); x != vmapt_end; x++) {
std::cout << *x << std::endl;
}
return 0;
}
Run Code Online (Sandbox Code Playgroud)
如果你编译并运行它,你应该看到这样的东西:
$ nvcc -arch=sm_12 remove_if.cu
$ ./a.out
1
2
4
5
7
8
10
Run Code Online (Sandbox Code Playgroud)
在这个例子中,我只"检索"元组的第一个元素的短路内容,第二个元素以相同的方式访问,即.标记向量新结尾的迭代器是thrust::get<1>(Tend).