Skip to content

Commit

Permalink
Disable iterator support in foreach (for now)
Browse files Browse the repository at this point in the history
  • Loading branch information
wmaxey committed Oct 4, 2024
1 parent 627bd3d commit 6fb017e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 18 deletions.
5 changes: 5 additions & 0 deletions c/src/for.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ extern "C" CCCL_C_API CUresult cccl_device_for_build(
{
CUresult error = CUDA_SUCCESS;

if (d_data.type == cccl_iterator_kind_t::iterator)
{
throw std::runtime_error(std::string("Iterators are unsupported in for_each currently"));
}

try
{
nvrtcProgram prog{};
Expand Down
9 changes: 4 additions & 5 deletions c/src/for/for_op_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <cstdlib>
#include <cstring>
#include <format>
#include <memory>
#include <string>
#include <string_view>
#include <type_traits>
Expand Down Expand Up @@ -86,9 +85,9 @@ static std::string get_for_kernel_user_op(cccl_op_t user_op, cccl_iterator_t ite
#define _USER_OP_INPUT_T {2}
#if defined(_STATEFUL_USER_OP)
extern "C" __device__ void _USER_OP(void *, _USER_OP_INPUT_T);
extern "C" __device__ void _USER_OP(void*, _USER_OP_INPUT_T*);
#else
extern "C" __device__ void _USER_OP(_USER_OP_INPUT_T);
extern "C" __device__ void _USER_OP(_USER_OP_INPUT_T*);
#endif
#if defined(_STATEFUL_USER_OP)
Expand All @@ -98,7 +97,7 @@ struct __align__({3}) user_op_t {{
struct user_op_t {{
#endif
__device__ void operator()(_USER_OP_INPUT_T input) {{
__device__ void operator()(_USER_OP_INPUT_T* input) {{
#if defined(_STATEFUL_USER_OP)
_USER_OP(&data, input);
#else
Expand Down Expand Up @@ -148,7 +147,7 @@ struct for_each_wrapper
__device__ void operator()(unsigned long long idx)
{{
user_op(iterator[idx]);
user_op(iterator + idx);
}}
}};
Expand Down
10 changes: 5 additions & 5 deletions c/test/c2h.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ static std::string get_for_op(cccl_type_enum t)
switch (t)
{
case cccl_type_enum::INT8:
return "extern \"C\" __device__ void op(char a) {}";
return "extern \"C\" __device__ void op(char* a) {(*a)++;}";
case cccl_type_enum::INT32:
return "extern \"C\" __device__ void op(int a) {}";
return "extern \"C\" __device__ void op(int* a) {(*a)++;}";
case cccl_type_enum::UINT32:
return "extern \"C\" __device__ void op(unsigned int a) {}";
return "extern \"C\" __device__ void op(unsigned int* a) {(*a)++;}";
case cccl_type_enum::INT64:
return "extern \"C\" __device__ void op(long long a) {}";
return "extern \"C\" __device__ void op(long long* a) {(*a)++;}";
case cccl_type_enum::UINT64:
return "extern \"C\" __device__ void op(unsigned long long a) {}";
return "extern \"C\" __device__ void op(unsigned long long* a) {(*a)++;}";
default:
throw std::runtime_error("Unsupported type");
}
Expand Down
44 changes: 36 additions & 8 deletions c/test/test_for.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <cuda_runtime.h>

#include <algorithm>

#include "c2h.h"
#include <cccl/c/for.h>

Expand Down Expand Up @@ -45,10 +47,21 @@ TEMPLATE_LIST_TEST_CASE("for works with integral types", "[for]", integral_types

operation_t op = make_operation("op", get_for_op(get_type_info<TestType>().type));
std::vector<TestType> input(num_items, TestType(1));

pointer_t<TestType> input_ptr(input);

for_each(input_ptr, num_items, op);

// Copy back input array
input = input_ptr;
bool all_match = true;
std::for_each(input.begin(), input.end(), [&](auto v) {
if (v != 2)
{
all_match = false;
}
});

REQUIRE(all_match);
}

struct pair
Expand All @@ -64,13 +77,25 @@ TEST_CASE("for works with custom types", "[for]")
operation_t op = make_operation("op",
R"XXX(
struct pair { short a; size_t b; };
extern "C" __device__ void op(pair a) {}
extern "C" __device__ void op(pair* a) {a->a++; a->b++;}
)XXX");

std::vector<pair> input(num_items, pair{short(1), size_t(1)});
pointer_t<pair> input_ptr(input);

for_each(input_ptr, num_items, op);

// Copy back input array
input = input_ptr;
bool all_match = true;
std::for_each(input.begin(), input.end(), [&](auto v) {
if (v.a != 2 || v.b != 2)
{
all_match = false;
}
});

REQUIRE(all_match);
}

struct invocation_counter_state_t
Expand All @@ -87,13 +112,13 @@ TEST_CASE("for works with stateful operators", "[for]")
"op",
R"XXX(
struct invocation_counter_state_t { int* d_counter; };
extern "C" __device__ void op(invocation_counter_state_t* state, int a) {
atomicAdd(state->d_counter, 1);
extern "C" __device__ void op(invocation_counter_state_t* state, int* a) {
atomicAdd(state->d_counter, *a);
}
)XXX",
op_state);

const std::vector<int> input = generate<int>(num_items);
std::vector<int> input(num_items, 1);
pointer_t<int> input_ptr(input);

for_each(input_ptr, num_items, op);
Expand Down Expand Up @@ -123,13 +148,13 @@ struct large_state_t
int* d_counter;
int y, z, a;
};
extern "C" __device__ void op(large_state_t* state, int a) {
atomicAdd(state->d_counter, 1);
extern "C" __device__ void op(large_state_t* state, int* a) {
atomicAdd(state->d_counter, *a);
}
)XXX",
op_state);

const std::vector<int> input = generate<int>(num_items);
std::vector<int> input(num_items, 1);
pointer_t<int> input_ptr(input);

for_each(input_ptr, num_items, op);
Expand All @@ -144,6 +169,8 @@ struct constant_iterator_state_t
T value;
};

// TODO:
/*
TEST_CASE("for works with iterators", "[for]")
{
const int num_items = GENERATE(1, 42, take(4, random(1 << 12, 1 << 16)));
Expand Down Expand Up @@ -174,3 +201,4 @@ extern "C" __device__ void op(invocation_counter_state_t* state, int a) {
const int invocation_count = counter[0];
REQUIRE(invocation_count == num_items);
}
*/

0 comments on commit 6fb017e

Please sign in to comment.