SCALAR MUL Backward (#256)
* added tests

* resolve test comments

* REQUIRE_EQ_WITH_TOLERANCE
weiyanlin117 authored Aug 8, 2023
1 parent 6f2be6a commit 3322dfa
Showing 3 changed files with 146 additions and 0 deletions.
68 changes: 68 additions & 0 deletions lib/nnc/cmd/blas/mps/ccv_nnc_mul_mps.m
@@ -338,6 +338,65 @@ static int _ccv_nnc_scalar_mul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
	return CCV_NNC_EXEC_SUCCESS;
}

static int _ccv_nnc_scalar_mul_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
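	// Backward of scalar multiplication: given y = p * x, the gradient is
	// dx = p * dy. When no output gradient is supplied (inputs[0] == 0), dy is
	// treated as all-ones, so the result is simply filled with the scalar p.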
	const float p = cmd.info.blas.a[0];
	ccv_nnc_tensor_view_t* const c = (ccv_nnc_tensor_view_t*)outputs[0];

	if (inputs[0] == 0)
	{
		@autoreleasepool {
			MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, inputs, input_size, outputs, output_size);
			int indices[1];
			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
				MPSGraphShapedType* mps_c_shape = ccv_nnc_mps_graph_tensor_input_shape(c, c->info.dim, c->stride);
				MPSGraphTensor* mps_c = [graph constantWithScalar:p shape:mps_c_shape.shape dataType:ccv_nnc_mps_datatype(c->info.datatype)];
				[resultTensors addObject:mps_c];
			});
			ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
			ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
		}
		return CCV_NNC_EXEC_SUCCESS;
	}

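	// An output gradient was supplied: compute dx = dy * p, with dy in inputs[0].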
	const ccv_nnc_tensor_view_t* const a = (const ccv_nnc_tensor_view_t*)inputs[0];
	@autoreleasepool {
		MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
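		// Fast path: with p == 1 the backward pass is an identity, so the data
		// is copied through (or run through a conversion graph when the tensor
		// layout requires it) without any multiplication.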
		if (p == 1)
		{
			MPSGraph* graph = [MPSGraph new];
			graph.options = MPSGraphOptionsSynchronizeResults;
			MPSGraphTensor* mps_input_a;
			MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
			MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
			if (mps_a != mps_input_a)
				ccv_nnc_mps_graph_result(graph, command_buffer, @{mps_input_a: data_a}, mps_a, c, c->info.dim, c->stride);
			else
				ccv_nnc_mps_export_data(data_a, command_buffer, c, c->info.dim, c->stride);
			[graph release];
		} else {
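			// General path: fetch (or build and cache) an executable that
			// multiplies the incoming gradient by the constant p.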
			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, inputs, input_size, outputs, output_size);
			int indices[1];
			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
				MPSGraphTensor* mps_input_a;
				MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
				[inputTensors addObject:mps_input_a];
				MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
				[inputShapedTypes addObject:mps_a_shape];
				MPSGraphTensor* mps_p = [graph constantWithScalar:p dataType:ccv_nnc_mps_datatype(a->info.datatype)];
				MPSGraphTensor* mps_c = [graph multiplicationWithPrimaryTensor:mps_a secondaryTensor:mps_p name:nil];
				[resultTensors addObject:mps_c];
			});
			MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
			ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data_a], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1);
		}
		ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
	}
	return CCV_NNC_EXEC_SUCCESS;
}

REGISTER_COMMAND_BACKEND(CCV_NNC_SCALAR_MUL_FORWARD, CCV_NNC_BACKEND_MPS)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
@@ -346,3 +405,12 @@ static int _ccv_nnc_scalar_mul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_scalar_mul_forw;
}

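// The backward backend mirrors the forward registration above: same tensor
// formats, FP32/FP16 datatypes, and GPU memory requirement.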
REGISTER_COMMAND_BACKEND(CCV_NNC_SCALAR_MUL_BACKWARD, CCV_NNC_BACKEND_MPS)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
	registry->tensor_datatypes = CCV_32F | CCV_16F;
	registry->tensor_memory = CCV_TENSOR_GPU_MEMORY;
	registry->algorithms = 1;
	registry->exec = _ccv_nnc_scalar_mul_back;
}
2 changes: 2 additions & 0 deletions lib/nnc/cmd/ccv_nnc_cmd.inc
@@ -565,6 +565,7 @@ void _register_command_CCV_NNC_ADD_BACKWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_
void _register_command_CCV_NNC_MUL_FORWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_MUL_BACKWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_SCALAR_MUL_FORWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_SCALAR_MUL_BACKWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_CONVOLUTION_FORWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_CONVOLUTION_BACKWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
void _register_command_CCV_NNC_EWSUM_FORWARD_backend_CCV_NNC_BACKEND_MPS(ccv_nnc_cmd_backend_registry_t* const registry);
@@ -1004,6 +1005,7 @@ static inline void _ccv_nnc_cmd_init(void)
	_register_command_CCV_NNC_MUL_FORWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[68].backends[6]));
	_register_command_CCV_NNC_MUL_BACKWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[69].backends[6]));
	_register_command_CCV_NNC_SCALAR_MUL_FORWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[124].backends[6]));
	_register_command_CCV_NNC_SCALAR_MUL_BACKWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[125].backends[6]));
	_register_command_CCV_NNC_CONVOLUTION_FORWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[104].backends[6]));
	_register_command_CCV_NNC_CONVOLUTION_BACKWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[105].backends[6]));
	_register_command_CCV_NNC_EWSUM_FORWARD_backend_CCV_NNC_BACKEND_MPS(&(init_map[40].backends[6]));
76 changes: 76 additions & 0 deletions test/int/nnc/mpsdnn.tests.c
@@ -2276,4 +2276,80 @@ TEST_CASE("broadcasting semantics for mul backward (no input grad) for a")
	ccv_nnc_tensor_free(gdb);
}

TEST_CASE("mps scalarmul forward")
{
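	// Forward check: y = 1.1 * x computed on the GPU via MPS, compared against
	// the CPU-side values within 1e-5 tolerance.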
	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALAR_MUL_FORWARD, CCV_NNC_BACKEND_MPS));

	ccv_nnc_tensor_t* const x = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_tensor_t* const gx = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 4), 0);

	dsfmt_t dsfmt;
	dsfmt_init_gen_rand(&dsfmt, 0);
	int i;
	for (i = 0; i < 4; i++)
		x->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(x), TENSOR_LIST(gx), 0);

	ccv_nnc_tensor_t* const gy = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_SCALAR_MUL_FORWARD(1.1), ccv_nnc_no_hint, 0, TENSOR_LIST(gx), TENSOR_LIST(gy), 0);

	ccv_nnc_tensor_t* const y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gy), TENSOR_LIST(y), 0);
	for (i = 0; i < 4; i++) {
		REQUIRE_EQ_WITH_TOLERANCE(x->data.f32[i] * 1.1, y->data.f32[i], 1e-5, "scalarmul forward y has to be 1.1 * x");
	}

	ccv_nnc_tensor_free(x);
	ccv_nnc_tensor_free(gx);
	ccv_nnc_tensor_free(gy);
	ccv_nnc_tensor_free(y);
}

TEST_CASE("mps scalarmul backward")
{
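	// Backward check: dx = 1.1 * dy, with dy uploaded to the GPU and the
	// result transferred back for comparison.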
	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALAR_MUL_BACKWARD, CCV_NNC_BACKEND_MPS));

	ccv_nnc_tensor_t* const y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);

	dsfmt_t dsfmt;
	dsfmt_init_gen_rand(&dsfmt, 0);
	int i;
	for (i = 0; i < 4; i++)
		y->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
	ccv_nnc_tensor_t* const gy = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 4), 0);
	ccv_nnc_tensor_t* const gdx = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(y), TENSOR_LIST(gy), 0);
	ccv_nnc_cmd_exec(CMD_SCALAR_MUL_BACKWARD(1.1), ccv_nnc_no_hint, 0, TENSOR_LIST(gy), TENSOR_LIST(gdx), 0);

	ccv_nnc_tensor_t* const dx = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gdx), TENSOR_LIST(dx), 0);

	for (i = 0; i < 4; i++) {
		REQUIRE_EQ_WITH_TOLERANCE(dx->data.f32[i], y->data.f32[i] * 1.1, 1e-5, "scalarmul backward dx has to be 1.1 * dy");
	}

	ccv_nnc_tensor_free(y);
	ccv_nnc_tensor_free(gy);
	ccv_nnc_tensor_free(gdx);
	ccv_nnc_tensor_free(dx);
}

TEST_CASE("mps scalarmul backward, no input")
{
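	// No-input-gradient check: the backward kernel should fill the output with
	// the scalar itself (dy defaults to all-ones).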
	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALAR_MUL_BACKWARD, CCV_NNC_BACKEND_MPS));

	ccv_nnc_tensor_t* const gdx = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_SCALAR_MUL_BACKWARD(1.1), ccv_nnc_no_hint, 0, TENSOR_LIST(0), TENSOR_LIST(gdx), 0);
	ccv_nnc_tensor_t* const dx = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 4), 0);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gdx), TENSOR_LIST(dx), 0);

	for (int i = 0; i < 4; i++)
		REQUIRE_EQ_WITH_TOLERANCE(dx->data.f32[i], 1.1, 1e-5, "scalarmul backward without input has to be 1.1");
	ccv_nnc_tensor_free(gdx);
	ccv_nnc_tensor_free(dx);
}

#include "case_main.h"
