rewrite reduce ops

hubertdelajonquieresonos committed Jul 19, 2024
1 parent 87e13d2 commit c43f8e2
Showing 1 changed file with 56 additions and 77 deletions.

metal/src/kernels/nn/nn_ops.metal (133 changes: 56 additions & 77 deletions)
@@ -87,8 +87,8 @@ struct Prod {

 template<typename F, typename Op>
 [[kernel]] void reduce_nd3(
-    device const F *input,
-    device F *output,
+    device const void *input_b,
+    device void *output_b,
     constant const size_t input_shape[3],
     constant const size_t input_strides[3],
     constant const size_t output_strides[3],
@@ -97,6 +97,9 @@ template<typename F, typename Op>
     uint tpsg[[threads_per_simdgroup]]
 ) {
 
+    device const F *input = (device const F *)input_b;
+    device F *output = (device F *)output_b;
+
     Op op = Op();
 
     size_t reduce_dim = input_shape[1];
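
Note: the rest of the reduce_nd3 body is collapsed in this view. Judging from the rms_norm_nd3 body further down, it presumably follows the same simdgroup pattern; a minimal sketch of that pattern, where op.apply, identity, and the indexing are illustrative names rather than the file's actual code:

    // Each SIMD lane strides over the reduced dimension, then lanes combine.
    float partial = identity;                                  // hypothetical per-op identity value
    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
        partial = op.apply(partial, input[/* strided index */ i]);  // hypothetical accessor
    }
    // A simdgroup intrinsic such as simd_sum then folds the per-lane partials.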
@@ -120,26 +123,49 @@
 }
 }
 
+typedef decltype(reduce_nd3<float, Prod<float>>) reduce_nd3_t;
+
+#define INSTANTIATE_REDUCE(name, op, tname, type)                  \
+template [[host_name("nn_ops::reduce_" #name "_nd3_" #tname)]]     \
+[[kernel]] reduce_nd3_t reduce_nd3<type, op<type>>;
+
+
+INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f32, float)
+INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f16, half)
+INSTANTIATE_REDUCE(sum, Sum, f32, float)
+INSTANTIATE_REDUCE(sum, Sum, f16, half)
+INSTANTIATE_REDUCE(min, Min, f32, float)
+INSTANTIATE_REDUCE(min, Min, f16, half)
+INSTANTIATE_REDUCE(max, Max, f32, float)
+INSTANTIATE_REDUCE(max, Max, f16, half)
+INSTANTIATE_REDUCE(prod, Prod, f32, float)
+INSTANTIATE_REDUCE(prod, Prod, f16, half)


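Note: the new instantiation form is much shorter because the reduce_nd3_t typedef carries the full kernel signature. Hand-expanding one invocation of the macro above gives:

    // INSTANTIATE_REDUCE(sum, Sum, f32, float) expands to:
    template [[host_name("nn_ops::reduce_sum_nd3_f32")]]
    [[kernel]] reduce_nd3_t reduce_nd3<float, Sum<float>>;
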
 template<typename F>
 [[kernel]] void rms_norm_nd3(
-    device const F *input,
-    device const F & eps,
-    device F *output,
+    device const void *input_b,
+    constant void * eps_b,
+    device void *output_b,
     constant const size_t shape[3],
     constant const size_t strides[3],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint tiisg[[thread_index_in_simdgroup]],
     uint tpsg[[threads_per_simdgroup]]
 ) {
 
+    device const F* input = (device const F*) input_b;
+    F eps = ((constant F *)eps_b)[0];
+    device F * output = (device F*) output_b;
+
     size_t dim = shape[1];
 
     size_t base_idx = tgpig.x * strides[2]
                     + tgpig.z * strides[0];
 
     float partial_acc = 0.0;
     for (size_t i = tiisg; i < dim; i += tpsg) {
-        F el = input[base_idx + i * strides[1]];
+        float el = static_cast<float>(input[base_idx + i * strides[1]]);
         partial_acc += el * el;
     }
     float mean_of_squares = simd_sum(partial_acc) / static_cast<float>(dim);
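
Note: switching the accumulator element from F to float matters when F = half, whose maximum finite value is 65504: squaring any element above sqrt(65504) ≈ 256 would overflow to infinity before the sum even starts. An illustrative case:

    // With F = half and an input element of 300.0h:
    //   in half:  300.0h * 300.0h = 90000   -> inf (exceeds 65504)
    //   in float: 300.0f * 300.0f = 90000.0 -> fine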
@@ -152,6 +178,12 @@ template<typename F>
 }
 }
 
+typedef decltype(rms_norm_nd3<float>) rms_norm_nd3_t;
+
+template [[host_name("nn_ops::rms_norm_nd3_f32")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<float>;
+template [[host_name("nn_ops::rms_norm_nd3_f16")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<half>;
+
+
 struct Sigmoid {
     template <typename T>
     T operator()(T x) {
@@ -161,24 +193,35 @@ struct Sigmoid {
 };
 
 template<typename T>
-[[kernel]] void silu(device const T *input [[buffer(0)]],
-                     device T *output [[buffer(1)]],
+[[kernel]] void silu(device const void *input_b [[buffer(0)]],
+                     device void *output_b [[buffer(1)]],
                      uint tpig[[thread_position_in_grid]]) {
+    device const T *input = (device const T *)input_b;
+    device T *output = (device T *)output_b;
+
     output[tpig] = Sigmoid()(input[tpig]) * input[tpig];
 }
 
+typedef decltype(silu<float>) silu_t;
+
+template [[host_name("nn_ops::silu_f32")]] [[kernel]] silu_t silu<float>;
+template [[host_name("nn_ops::silu_f16")]] [[kernel]] silu_t silu<half>;

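Note: silu computes x * sigmoid(x), the standard SiLU/swish activation, read directly off the body above. A quick numeric sanity check:

    // silu(0)  =  0 * 0.5             =  0
    // silu(1)  =  1 * 1/(1 + e^-1)    ≈  0.731
    // silu(-1) = -1 * 1/(1 + e^1)     ≈ -0.269
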
 template<typename F>
 [[kernel]] void softmax_nd3(
-    device const F *input,
-    device F *output,
+    device const void *input_b,
+    device void *output_b,
     constant const size_t shape[3],
     constant const size_t strides[3],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint tiisg[[thread_index_in_simdgroup]],
     uint tpsg[[threads_per_simdgroup]]
 ) {
 
+    device const F *input = (device const F *)input_b;
+    device F *output = (device F *)output_b;
+
     size_t dim = shape[1];
 
     size_t base_idx = tgpig.x * strides[2]
@@ -214,73 +257,9 @@
 }
 }
 
-#define INSTANTIATE_REDUCE(name, op, tname, type)                  \
-template [[host_name("nn_ops::reduce_" #name "_nd3_" #tname)]]     \
-[[kernel]] void reduce_nd3<type, op<type>>(                        \
-    device const type *input,                                      \
-    device type *output,                                           \
-    constant const size_t input_shape[3],                          \
-    constant const size_t input_strides[3],                        \
-    constant const size_t output_strides[3],                       \
-    uint3 tgpig[[threadgroup_position_in_grid]],                   \
-    uint tiisg[[thread_index_in_simdgroup]],                       \
-    uint tpsg[[threads_per_simdgroup]]                             \
-);
-
-
-INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f32, float)
-INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f16, half)
-INSTANTIATE_REDUCE(sum, Sum, f32, float)
-INSTANTIATE_REDUCE(sum, Sum, f16, half)
-INSTANTIATE_REDUCE(min, Min, f32, float)
-INSTANTIATE_REDUCE(min, Min, f16, half)
-INSTANTIATE_REDUCE(max, Max, f32, float)
-INSTANTIATE_REDUCE(max, Max, f16, half)
-INSTANTIATE_REDUCE(prod, Prod, f32, float)
-INSTANTIATE_REDUCE(prod, Prod, f16, half)
-
-#define INSTANTIATE_SOFTMAX(tname, type)                           \
-template [[host_name("nn_ops::softmax_nd3_" #tname)]]              \
-[[kernel]] void softmax_nd3<type>(                                 \
-    device const type *input,                                      \
-    device type *output,                                           \
-    constant const size_t shape[3],                                \
-    constant const size_t strides[3],                              \
-    uint3 tgpig[[threadgroup_position_in_grid]],                   \
-    uint tiisg[[thread_index_in_simdgroup]],                       \
-    uint tpsg[[threads_per_simdgroup]]                             \
-);
-
-INSTANTIATE_SOFTMAX(f32, float)
-INSTANTIATE_SOFTMAX(f16, half)
-
-#define INSTANTIATE_RMS_NORM(tname, type)                          \
-template [[host_name("nn_ops::rms_norm_nd3_" #tname)]]             \
-[[kernel]] void rms_norm_nd3<type>(                                \
-    device const type *input,                                      \
-    device const type &eps,                                        \
-    device type *output,                                           \
-    constant const size_t shape[3],                                \
-    constant const size_t strides[3],                              \
-    uint3 tgpig[[threadgroup_position_in_grid]],                   \
-    uint tiisg[[thread_index_in_simdgroup]],                       \
-    uint tpsg[[threads_per_simdgroup]]                             \
-);
-
-INSTANTIATE_RMS_NORM(f32, float)
-INSTANTIATE_RMS_NORM(f16, half)
-
-#define INSTANTIATE_SILU(tname, type)                              \
-template [[host_name("nn_ops::silu_" #tname)]]                     \
-[[kernel]] void silu<type>(                                        \
-    device const type *input [[buffer(0)]],                        \
-    device type *output [[buffer(1)]],                             \
-    uint tpig[[thread_position_in_grid]]                           \
-);
-
-INSTANTIATE_SILU(f32, float)
-INSTANTIATE_SILU(f16, half)
-
+typedef decltype(softmax_nd3<float>) softmax_nd3_t;
+
+template [[host_name("nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3<float>;
+template [[host_name("nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3<half>;
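Note: as with the other kernels, the decltype typedef is equivalent to the deleted spelled-out instantiation macros. Hand-expanding the new f32 line against the softmax_nd3 signature above gives back the long form:

    template [[host_name("nn_ops::softmax_nd3_f32")]]
    [[kernel]] void softmax_nd3<float>(
        device const void *input_b,
        device void *output_b,
        constant const size_t shape[3],
        constant const size_t strides[3],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint tpsg[[threads_per_simdgroup]]);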
