Cuda Thrust 自定义函数

发布于 2024-12-05 06:21:04 字数 463 浏览 1 评论 0原文

如何在 Thrust 中实现这个功能?

for (i=0;i<n;i++)
    if (i==pos)
        h1[i]=1/h1[i];
    else
        h1[i]=-h1[i]/value;

在 CUDA 中我这样做是这样的:

__global__ void inverse_1(double* h1, double value, int pos, int N)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < N){
        if (i == pos)
            h1[i] = 1 / h1[i];
        else
            h1[i] = -h1[i] / value;
    }
}

谢谢!

How can I impliment this function in Thrust?

for (i=0;i<n;i++)
    if (i==pos)
        h1[i]=1/h1[i];
    else
        h1[i]=-h1[i]/value;

In CUDA I did it like:

__global__ void inverse_1(double* h1, double value, int pos, int N)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < N){
        if (i == pos)
            h1[i] = 1 / h1[i];
        else
            h1[i] = -h1[i] / value;
    }
}

Thanks!

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

一个人的旅程 2024-12-12 06:21:05

您需要创建一个二元仿函数来应用该操作,然后使用计数迭代器作为第二个输入。您可以将 posvalue 传递到仿函数的构造函数中。它看起来像:

struct inv1_functor
{
  const int pos;
  const double value;

  inv1_functor(double _value, int _pos) : value(_value), pos(_pos) {}

  __host__ __device__
  double operator()(const double &x, const int &i) const {
    if (i == pos)
      return 1.0/x;
    else
      return -x/value;
  }
};

//...

thrust::transform(d_vec.begin(), d_vec.end(), thrust::counting_iterator<int>(),  d_vec.begin(), inv1_functor(value, pos));

You need to create a binary functor to apply the operation, then use a counting iterator as the second input. You can pass pos and value into the functor's constructor. It'd look something like:

struct inv1_functor
{
  const int pos;
  const double value;

  inv1_functor(double _value, int _pos) : value(_value), pos(_pos) {}

  __host__ __device__
  double operator()(const double &x, const int &i) const {
    if (i == pos)
      return 1.0/x;
    else
      return -x/value;
  }
};

//...

thrust::transform(d_vec.begin(), d_vec.end(), thrust::counting_iterator<int>(),  d_vec.begin(), inv1_functor(value, pos));
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文