Skip to content

Commit

Permalink
MPS GEMM Backward (#257)
Browse files Browse the repository at this point in the history
* implement and tests

* unnecessary header import

* still need to figure out a way to add dw and h dynamically

* separate h, dw, dbias calculation

* test for no h, no dw
  • Loading branch information
weiyanlin117 authored Aug 10, 2023
1 parent 239ff5c commit 664a61d
Show file tree
Hide file tree
Showing 2 changed files with 1,093 additions and 1 deletion.
170 changes: 169 additions & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "nnc/ccv_nnc.h"
#include "nnc/ccv_nnc_easy.h"
#include "nnc/ccv_nnc_internal.h"
#include <Foundation/Foundation.h>
#ifdef HAVE_MPS
#include "nnc/mps/ccv_nnc_mps.h"
#endif
Expand Down Expand Up @@ -352,7 +353,174 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
}

static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	// Backward pass of GEMM on MPS.
	// inputs: gradient g, forward-prop input a, [w]
	// outputs: input gradient h, weight gradient dw, bias gradient bias
	// Each requested output (h, dw, bias) is computed by its own cached MPSGraph
	// executable so any subset of gradients can be requested independently.
	assert(input_size >= 2 && output_size >= 2);
	const ccv_nnc_tensor_view_t* g = (const ccv_nnc_tensor_view_t*)inputs[0];
	ccv_nnc_tensor_view_t* dw = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
	ccv_nnc_tensor_view_t* bias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;

	const ccv_nnc_tensor_view_t* a = input_size > 1 ? (const ccv_nnc_tensor_view_t*)inputs[1] : 0;
	ccv_nnc_tensor_view_t* h = (ccv_nnc_tensor_view_t*)outputs[0];
	const ccv_nnc_tensor_view_t* w = input_size > 2 ? (const ccv_nnc_tensor_view_t*)inputs[2] : 0;

	assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 2-d or 3-d array
	const int is_transpose_a = a ? ccv_nnc_is_matrix_transpose(a->info, cmd.info.blas.transpose_a) : 0;
	const int is_transpose_w = w ? ccv_nnc_is_matrix_transpose(w->info, cmd.info.blas.transpose_b) : 0;

	@autoreleasepool {
		MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);

		if (h) {
			assert(w); // To compute h = g * w^T, w must be provided.
			// [input gradient] h = g x w^T (with transposes folded in below).
			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, (ccv_nnc_tensor_t*[]){ g, w }, 2, (ccv_nnc_tensor_t*[]){ h }, 1);
			// Two graph inputs (g, w) feed this executable, so two permutation
			// indices are filled in. NOTE(review): this was `int indices[1]`,
			// which made the `indices[1]` read below an out-of-bounds access.
			int indices[2];

			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
				MPSGraphTensor* mps_input_g;
				MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, g->info.dim, g->stride, &mps_input_g);
				[inputTensors addObject:mps_input_g];
				MPSGraphShapedType* mps_g_shape = ccv_nnc_mps_graph_tensor_input_shape(g, g->info.dim, g->stride);
				[inputShapedTypes addObject:mps_g_shape];

				MPSGraphTensor* mps_input_w;
				MPSGraphTensor* mps_w = ccv_nnc_mps_graph_tensor_input(graph, w, w->info.dim, w->stride, &mps_input_w);
				[inputTensors addObject:mps_input_w];
				MPSGraphShapedType* mps_w_shape = ccv_nnc_mps_graph_tensor_input_shape(w, w->info.dim, w->stride);
				[inputShapedTypes addObject:mps_w_shape];

				// h = g x w^T; when w was already stored transposed, it is
				// consumed as-is, otherwise transpose the last two dims here.
				if (!is_transpose_w)
					mps_w = [graph transposeTensor:mps_w dimension:-2 withDimension:-1 name:nil];

				MPSGraphShapedType* mps_h_target_shape = ccv_nnc_mps_graph_tensor_input_shape(h, h->info.dim, h->stride);

				MPSGraphTensor* mps_h = [graph matrixMultiplicationWithPrimaryTensor:mps_g secondaryTensor:mps_w name:nil];
				if (is_transpose_a)
					mps_h = [graph transposeTensor:mps_h dimension:-2 withDimension:-1 name:nil];

				const NSUInteger mps_h_nd = mps_h.shape.count;
				const NSUInteger h_target_nd = mps_h_target_shape.shape.count;

				// If target h has fewer dims than the product (e.g. batched g
				// against an unbatched h), left-pad the target with 1s and
				// reduce-sum every axis where the sizes disagree (broadcasted axes).
				if (h_target_nd < mps_h_nd) {
					NSMutableArray<NSNumber*>* h_target_shape = mps_h_target_shape.shape.mutableCopy;
					NSMutableArray<NSNumber*>* axes = [NSMutableArray new];

					NSUInteger i;
					for (i = 0; i < mps_h_nd - h_target_nd; i++)
						[h_target_shape insertObject:@(1) atIndex:0]; // [1,..,1,N]

					for (i = 0; i < mps_h_nd; i++) {
						if (mps_h.shape[i].integerValue != h_target_shape[i].integerValue)
							[axes addObject:@(i)];
					}
					mps_h = [graph reductionSumWithTensor:mps_h axes:axes name:nil];
				}
				[resultTensors addObject:mps_h];
			});
			MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
			MPSGraphTensorData* data_w = ccv_nnc_mps_graph_tensor_data(w, w->info.dim, w->stride);
			MPSGraphTensorData* data[] = {data_g, data_w};
			// Feed inputs in the order the cached executable expects.
			ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &h, (int*[]){ h->info.dim }, (int*[]){ h->stride }, 1);
		}

		if (dw) {
			assert(a); // To compute dw = a^T * g, a must be provided.

			// [weight gradient] dw = a^T x g (with transposes folded in below).
			ccv_nnc_mps_graph_key_t dw_key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, (ccv_nnc_tensor_t*[]){ g, a }, 2, (ccv_nnc_tensor_t*[]){ dw }, 1);
			int dw_indices[2];

			MPSGraphExecutable* executable_dw = ccv_nnc_mps_graph_executable_cache(dw_key, dw_indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
				MPSGraphTensor* mps_input_g;
				MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, g->info.dim, g->stride, &mps_input_g);
				[inputTensors addObject:mps_input_g];
				MPSGraphShapedType* mps_g_shape = ccv_nnc_mps_graph_tensor_input_shape(g, g->info.dim, g->stride);
				[inputShapedTypes addObject:mps_g_shape];

				MPSGraphTensor* mps_input_a;
				MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
				[inputTensors addObject:mps_input_a];
				MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
				[inputShapedTypes addObject:mps_a_shape];
				// dw = a^T x g; when a was already stored transposed it is
				// consumed as-is, otherwise transpose the last two dims here.
				if (!is_transpose_a)
					mps_a = [graph transposeTensor:mps_a dimension:-2 withDimension:-1 name:nil];

				MPSGraphShapedType* mps_dw_target_shape = ccv_nnc_mps_graph_tensor_input_shape(dw, dw->info.dim, dw->stride);

				MPSGraphTensor* mps_dw = [graph matrixMultiplicationWithPrimaryTensor:mps_a secondaryTensor:mps_g name:nil];
				if (is_transpose_w)
					mps_dw = [graph transposeTensor:mps_dw dimension:-2 withDimension:-1 name:nil];

				const NSUInteger mps_dw_nd = mps_dw.shape.count;
				const NSUInteger dw_target_nd = mps_dw_target_shape.shape.count;

				// If target dw has fewer dims than the product (e.g. batched
				// inputs against a shared weight), left-pad the target with 1s
				// and reduce-sum the mismatched (broadcasted) axes.
				if (dw_target_nd < mps_dw_nd) {
					NSMutableArray<NSNumber*>* dw_target_shape = mps_dw_target_shape.shape.mutableCopy;
					NSMutableArray<NSNumber*>* axes = [NSMutableArray new];

					NSUInteger i;
					for (i = 0; i < mps_dw_nd - dw_target_nd; i++)
						[dw_target_shape insertObject:@(1) atIndex:0]; // [1,..,1,N]

					for (i = 0; i < mps_dw_nd; i++) {
						if (mps_dw.shape[i].integerValue != dw_target_shape[i].integerValue)
							[axes addObject:@(i)];
					}
					mps_dw = [graph reductionSumWithTensor:mps_dw axes:axes name:nil];
				}

				[resultTensors addObject:mps_dw];
			});
			MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
			MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
			MPSGraphTensorData* data[] = {data_g, data_a};
			// Honor the executable's input permutation, same as the h path above.
			ccv_nnc_mps_graph_executable_result(executable_dw, command_buffer, @[data[dw_indices[0]], data[dw_indices[1]]], &dw, (int*[]){ dw->info.dim }, (int*[]){ dw->stride }, 1);
		}

		if (bias) {
			// [bias gradient] dbias = g reduced over every broadcasted axis.
			ccv_nnc_mps_graph_key_t db_key = ccv_nnc_mps_graph_key_new(cmd, hint, flags, (ccv_nnc_tensor_t*[]){ g }, 1, (ccv_nnc_tensor_t*[]){ bias }, 1);
			int db_indices[1];

			MPSGraphExecutable* executable_db = ccv_nnc_mps_graph_executable_cache(db_key, db_indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
				MPSGraphTensor* mps_input_g;
				MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, g->info.dim, g->stride, &mps_input_g);
				[inputTensors addObject:mps_input_g];
				MPSGraphShapedType* mps_g_shape = ccv_nnc_mps_graph_tensor_input_shape(g, g->info.dim, g->stride);
				[inputShapedTypes addObject:mps_g_shape];

				MPSGraphShapedType* mps_bias_shape = ccv_nnc_mps_graph_tensor_input_shape(bias, bias->info.dim, bias->stride);

				NSMutableArray<NSNumber*>* bias_target_shape = mps_bias_shape.shape.mutableCopy;
				NSMutableArray<NSNumber*>* axes = [NSMutableArray new];
				const int g_nd = ccv_nnc_tensor_nd(g->info.dim);
				const int bias_nd = ccv_nnc_tensor_nd(bias->info.dim);

				// Left-pad the bias shape with 1s to match g's rank before
				// locating the axes to reduce over.
				int i;
				for (i = 0; i < g_nd - bias_nd; i++)
					[bias_target_shape insertObject:@(1) atIndex:0]; // [1,..,1,N]

				for (i = 0; i < g_nd; i++) {
					if (g->info.dim[i] != bias_target_shape[i].integerValue)
						[axes addObject:@(i)];
				}
				MPSGraphTensor* mps_db = [graph reductionSumWithTensor:mps_g axes:axes name:nil];
				[resultTensors addObject:mps_db];
			});
			MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, g->info.dim, g->stride);
			// NOTE(review): the stride argument previously passed bias->info.dim
			// again; it must be bias->stride, matching the h and dw paths.
			ccv_nnc_mps_graph_executable_result(executable_db, command_buffer, @[data_g], &bias, (int*[]){ bias->info.dim }, (int*[]){ bias->stride }, 1);
		}

		ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
	}

	return CCV_NNC_EXEC_SUCCESS;
}

Expand Down
Loading

0 comments on commit 664a61d

Please sign in to comment.