CUDA thrust remove_if с последовательностью трафаретов

0

Я пытался удалить элементы из двух параметров thrust::device_vector<int> на основе значений первого вектора. Интуитивно я создал следующий снимок:

thrust::device_vector<float> idxToValue(COUNT_MAX);
thrust::device_vector<int> idxSorted(COUNT_MAX);
thrust::device_vector<int> groupIdxSorted(COUNT_MAX);
int count = COUNT_MAX;
float const minThreshold = MIN_THRESHOLD;

auto idxToValueSortedIter = thrust::make_permutation_iterator(
    idxToValue.begin()
    , idxSorted.begin()
    );

auto new_end = thrust::remove_if(
    thrust::make_zip_iterator(thrust::make_tuple(idxSorted.begin(), groupIdxSorted.begin()))
    , thrust::make_zip_iterator(thrust::make_tuple(idxSorted.begin() + count, groupIdxSorted.begin() + count))
    , idxToValueSortedIter 
    , thrust::placeholders::_1 >= minThreshold
    );

count = thrust::get<0>(new_end.get_iterator_tuple()) - idxSorted.begin();

К сожалению, документация Thrust

Диапазон [трафарет, трафарет + (последний - первый)) не должен перекрывать диапазон [результат, результат + (последний - первый))

Поэтому в моем случае idxToValueSortedIter, который используется в качестве последовательности трафаретов, зависит от idxSorted и фактически перекрывает результат (тот же вектор).

Есть ли способ решить эту проблему без копирования данных во временный вектор?

Теги:
cuda
thrust

1 ответ

2
Лучший ответ

Я думаю, что вы можете сделать это, просто используя remove_if версию remove_if (без трафарета, у нее нет такого ограничения на перекрытие трафарета с выходной последовательностью) и передачи вашего трафарета (т.е. Вашего итератора перестановки) как 3-й член вашего zip_iterator для remove_if плюс подходящий функтор выбора. Вот пример:

$ cat t572.cu
#include <iostream>
#include <thrust/device_vector.h>
#include <thrust/remove.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/copy.h>

#define COUNT_MAX 10
#define MIN_THRESHOLD 4.5f

struct my_functor
{
  float thresh;
  my_functor(float _thresh): thresh(_thresh) {}

  template <typename T>
  __host__ __device__
  bool operator()(T &mytuple) const {
    return thrust::get<2>(mytuple) > thresh;
  }
};

int main(){

  float h_idxToValue[COUNT_MAX] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
  int   h_idxSorted[COUNT_MAX] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  int   h_groupIdxSorted[COUNT_MAX] = {20, 21, 22, 23, 24, 25, 26, 27, 28, 29};

  thrust::device_vector<float> idxToValue(h_idxToValue, h_idxToValue + COUNT_MAX);
  thrust::device_vector<int> idxSorted(h_idxSorted, h_idxSorted + COUNT_MAX);
  thrust::device_vector<int> groupIdxSorted(h_groupIdxSorted, h_groupIdxSorted + COUNT_MAX);
  int count = COUNT_MAX;
  float const minThreshold = MIN_THRESHOLD;

  auto new_end = thrust::remove_if(
    thrust::make_zip_iterator(thrust::make_tuple(idxSorted.begin(), groupIdxSorted.begin(), thrust::make_permutation_iterator(idxToValue.begin(), idxSorted.begin())))
    , thrust::make_zip_iterator(thrust::make_tuple(idxSorted.begin() + count, groupIdxSorted.begin() + count, thrust::make_permutation_iterator(idxToValue.begin(), idxSorted.begin() + count)))
    , my_functor(minThreshold)
    );

  count = thrust::get<0>(new_end.get_iterator_tuple()) - idxSorted.begin();

  std::cout << "count = " << count << std::endl;
  thrust::copy_n(groupIdxSorted.begin(), count, std::ostream_iterator<int>(std::cout, ","));
  std::cout << std::endl;
  return 0;
}

$ nvcc -arch=sm_20 -std=c++11 -o t572 t572.cu
$ ./t572
count = 5
25,26,27,28,29,
$

Обычно мы ожидаем, что функция remove_if с предоставленным функтором удалит записи, значение idxToValue которых превышает пороговое значение (4.5). Однако из-за итератора перестановки и обратной последовательности упорядочения в idxSorted мы видим, что значения выше порога сохраняются, а остальные удаляются. Приведенный выше пример был с CUDA 6.5 и Fedora 20, чтобы воспользоваться экспериментальной поддержкой С++ 11.

  • 0
    Я на самом деле искал способ избежать пользовательского функтора (то есть с помощью заполнителей), но это то, что я просил, и это работает! Спасибо!

Ещё вопросы

Сообщество Overcoder
Наверх
Меню