Make MPS GEMM more flexible on batch stride.
liuliu committed Aug 18, 2024
1 parent 8d6197e commit ddd3f97
Showing 6 changed files with 78 additions and 234 deletions.
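For orientation before the diff: the GEMM parameter struct now takes a single batch count plus one element stride per operand, replacing the removed `batched` flag and the `batch_dims_a` / `batch_dims_b` / `batch_dims_d` arrays. Below is a minimal, stand-alone sketch of that encoding. The struct here is a simplified mirror of only the batch-related fields visible in this diff, not the real `ccv_nnc_mfa_gemm_params_t` definition, and the field types and sizes are assumptions.

/* Minimal sketch, not library code: a stand-alone mirror of the batch-related
   fields that appear in this diff. The real ccv_nnc_mfa_gemm_params_t lives in
   the nnc/mfa headers and carries more fields (M, N, K, data type, ...); the
   field types below are assumptions. */
#include <stdint.h>
#include <stdio.h>

typedef struct {
  uint8_t A_trans;
  uint8_t B_trans;
  uint8_t D_trans;
  uint8_t fused_bias;
  /* New encoding: one batch count for the dispatch, plus an element stride per
     operand. A stride of 0 broadcasts that operand across the whole batch. */
  uint32_t batch_dimension;
  uint64_t batch_stride_a;
  uint64_t batch_stride_b;
  uint64_t batch_stride_c;
  uint64_t batch_stride_d;
} gemm_batch_params_t; /* hypothetical mirror of ccv_nnc_mfa_gemm_params_t */

int main(void)
{
  /* Example: C[i] = A[i] * B for i = 0..7. A is 64x32 per batch element;
     B (32x16) is shared by every batch element, so its stride is 0. */
  const int batch = 8, M = 64, K = 32, N = 16;
  gemm_batch_params_t p = {
    .A_trans = 0,
    .B_trans = 0,
    .D_trans = 0,
    .fused_bias = 0,
    .batch_dimension = batch,
    .batch_stride_a = (uint64_t)M * K, /* a fresh A matrix per batch element */
    .batch_stride_b = 0,               /* broadcast: one B for all batches */
    .batch_stride_c = (uint64_t)M * N, /* each batch element writes its own C */
    .batch_stride_d = 0,
  };
  printf("batch=%u, stride_a=%llu, stride_b=%llu, stride_c=%llu\n",
         (unsigned)p.batch_dimension,
         (unsigned long long)p.batch_stride_a,
         (unsigned long long)p.batch_stride_b,
         (unsigned long long)p.batch_stride_c);
  return 0;
}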
162 changes: 38 additions & 124 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -56,6 +56,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, cmd.info.blas.transpose_b);
int biasdim[CCV_NNC_MAX_DIM_ALLOC] = {0};
int biasstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
int bias_batch_size = 1;
const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
if (bias)
{
@@ -69,6 +70,8 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
memcpy(biasdim, bias->info.dim, sizeof(biasdim));
if (CCV_IS_TENSOR_VIEW(bias))
memcpy(biasstride, bias->stride, sizeof(biasstride));
for (i = 0; i < bias_nd - 2; i++)
bias_batch_size *= biasdim[i];
} else if (bias_nd == 2) {
biasdim[0] = bias->info.dim[0];
for (i = 1; i < b_nd - 1; i++)
@@ -81,6 +84,8 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
biasstride[i] = biasstride[0];
biasstride[b_nd - 1] = bias->stride[1];
}
for (i = 0; i < bias_nd - 1; i++)
bias_batch_size *= biasdim[i];
} else {
for (i = 0; i < b_nd - 1; i++)
biasdim[i] = 1;
@@ -91,6 +96,8 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
biasstride[i] = bias->info.dim[0] * bias->stride[0];
biasstride[b_nd - 1] = bias->stride[0];
}
for (i = 0; i < bias_nd - 1; i++)
bias_batch_size *= biasdim[i];
}
}
int* adim_r = adim;
@@ -114,11 +121,17 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
if (w_batch_size == 1 && b_batch_size > 1)
w_batch_inc = 0;
@autoreleasepool {
// Fake the astride at a_nd - 3. For this one, we have the flexibility to change it for the v2 GEMM kernels.
const int astride_a_nd_3 = astride[a_nd - 3];
// Only fake it if it is larger than the expected compact stride.
if (astride_a_nd_3 > astride[a_nd - 2] * adim[a_nd - 2])
astride[a_nd - 3] = astride[a_nd - 2] * adim[a_nd - 2];
const int is_contiguous =
(!CCV_IS_TENSOR_VIEW(a) || ccv_nnc_tensor_view_is_contiguous(adim, astride)) &&
(!CCV_IS_TENSOR_VIEW(w) || ccv_nnc_tensor_view_is_contiguous(w->info.dim, w->stride)) &&
(!CCV_IS_TENSOR_VIEW(b) || ccv_nnc_tensor_view_is_contiguous(b->info.dim, b->stride)) &&
(bias ? (!CCV_IS_TENSOR_VIEW(bias) || ccv_nnc_tensor_view_is_contiguous(bias->info.dim, bias->stride)) : 1);
astride[a_nd - 3] = astride_a_nd_3;

const int a_datatype = CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX ? ((a->info.datatype & 0xff) << 12) : a->info.datatype;
const int w_datatype = CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX ? ((w->info.datatype & 0xff) << 12) : w->info.datatype;
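The `astride_a_nd_3` dance above is what makes the batch stride flexible: the batch-level stride of A is temporarily clamped to the compact value so that a view whose batch stride is merely padded still passes the contiguity check, then restored so the padded value can be forwarded as `batch_stride_a`. A stand-alone sketch of the idea for a simple 3-D [batch][rows][cols] view follows; `is_compact` is an illustrative stand-in, not `ccv_nnc_tensor_view_is_contiguous` itself, and the helper signatures are assumptions.

/* Sketch only: clamp-check-restore of the batch stride for a [batch][rows][cols]
   view. Helper names and signatures are illustrative, not the library API. */
static int is_compact(const int dim[3], const int stride[3])
{
  /* Fully packed layout: innermost stride 1, each outer stride equal to the
     product of the inner extents. */
  return stride[2] == 1 &&
         stride[1] == dim[2] &&
         stride[0] == dim[1] * dim[2];
}

/* Returns 1 if the view is contiguous apart from (at most) padding on the batch
   stride; *batch_stride receives the stride to hand to the GEMM kernel. */
static int batch_flexible_contiguous(const int dim[3], int stride[3], int *batch_stride)
{
  const int saved = stride[0];            /* original batch stride, possibly padded */
  const int compact = dim[1] * stride[1]; /* rows * row-stride, the compact batch stride */
  if (saved > compact)
    stride[0] = compact;                  /* pretend it is compact for the check */
  const int ok = is_compact(dim, stride);
  stride[0] = saved;                      /* restore the real stride */
  *batch_stride = saved > compact ? saved : compact;
  return ok;
}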
@@ -363,45 +376,15 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 1 : 0),
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
.batched = is_batched,
.fused_bias = (bias ? 1 : 0),
.register_float = (is_upcast ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = b_batch_size,
.batch_stride_a = a_batch_size > 1 ? ccv_max(astride_a_nd_3, b_rows * w_rows) : 0,
.batch_stride_b = w_batch_size > 1 ? b_cols * w_rows : 0,
.batch_stride_c = b_batch_size > 1 ? b_rows * b_cols : 0,
.batch_stride_d = bias_batch_size > 1 ? b_cols : 0,
};
if (is_batched) {
// Create a null-terminated list of batch dimensions.
int A_batch_dim = a_nd - 2;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = adim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = w_nd - 2;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = w->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}

int D_batch_dim = 0;
if (bias) {
const int bias_nd = ccv_nnc_tensor_nd(bias->info.dim);
assert(bias_nd <= a_nd);
D_batch_dim = bias_nd == a_nd ? bias_nd - 2 : bias_nd - 1;
}
for (int i = 0; i < D_batch_dim; ++i) {
params.batch_dims_d[i] = biasdim[i];
}
if (D_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_d[D_batch_dim] = 0;
}
}
ccv_nnc_mfa_prepare_gemm(context, params);

// Creating a new command buffer has a >10 µs penalty CPU-side. Still
@@ -699,7 +682,6 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

// NNC uses the convention B = A * W.
// MFA uses the convention C = A * B.
int is_batched = g_batch_size > 1;

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_supported =
@@ -792,31 +774,14 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
.batched = is_batched,
.fused_bias = 0,

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = g_batch_size,
.batch_stride_a = w_batch_size > 1 ? w_rows * w_cols : 0,
.batch_stride_b = g_batch_size > 1 ? g_rows * w_cols : 0,
.batch_stride_c = h_batch_size > 1 ? w_rows * g_rows : 0,
.batch_stride_d = 0,
};
if (is_batched) {
// Create a null-terminated list of batch dimensions.
int A_batch_dim = w_nd - 2;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = w->info.dim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = g_nd - 2;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = g->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}
}
ccv_nnc_mfa_prepare_gemm(context, params);
h_params = params;
} else {
@@ -828,31 +793,14 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = (is_transpose_w ? 0 : 1),
.D_trans = 0,
.batched = is_batched,
.fused_bias = 0,

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = g_batch_size,
.batch_stride_a = g_batch_size > 1 ? g_rows * w_cols : 0,
.batch_stride_b = w_batch_size > 1 ? w_rows * w_cols : 0,
.batch_stride_c = h_batch_size > 1 ? g_rows * w_rows : 0,
.batch_stride_d = 0,
};
if (is_batched) {
// Create a null-terminated list of batch dimensions.
int A_batch_dim = g_nd - 2;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = g->info.dim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = w_nd - 2;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = w->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}
}
ccv_nnc_mfa_prepare_gemm(context, params);
h_params = params;
}
@@ -872,31 +820,14 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_a ? 1 : 0),
.D_trans = 0,
.batched = is_batched,
.fused_bias = 0,

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = g_batch_size,
.batch_stride_a = g_batch_size > 1 ? dw_cols * g_rows : 0,
.batch_stride_b = a_batch_size > 1 ? dw_rows * g_rows : 0,
.batch_stride_c = dw_batch_size > 1 ? dw_cols * dw_rows : 0,
.batch_stride_d = 0,
};
if (is_batched) {
// Create a null-terminated list of batch dimensions.
int A_batch_dim = a_nd - 2;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = a->info.dim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = g_nd - 2;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = g->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}
}
ccv_nnc_mfa_prepare_gemm(context, params);
dw_params = params;
} else {
@@ -908,31 +839,14 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 0 : 1),
.B_trans = 0,
.D_trans = 0,
.batched = is_batched,
.fused_bias = 0,

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = g_batch_size,
.batch_stride_a = a_batch_size > 1 ? dw_rows * g_rows : 0,
.batch_stride_b = g_batch_size > 1 ? dw_cols * g_rows : 0,
.batch_stride_c = dw_batch_size > 1 ? dw_rows * dw_cols : 0,
.batch_stride_d = 0,
};
if (is_batched) {
// Create a null-terminated list of batch dimensions.
int A_batch_dim = g_nd - 2;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = g->info.dim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = a_nd - 2;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = a->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}
}
ccv_nnc_mfa_prepare_gemm(context, params);
dw_params = params;
}
68 changes: 16 additions & 52 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
@@ -94,15 +94,18 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const int w_nd = ccv_nnc_tensor_nd(w->info.dim);
const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
int is_batched = 0;
int a_batch_size;
int w_batch_size;
int b_batch_size;
if (use_mfa) {
int a_batch_size = a_nd < 4 ? 1 : adim[a_nd - 4];
a_batch_size = a_nd < 4 ? 1 : adim[a_nd - 4];
int i;
for (i = 0; i < a_nd - 4; i++)
a_batch_size *= adim[i];
int w_batch_size = w_nd < 5 ? 1 : w->info.dim[w_nd - 5];
w_batch_size = w_nd < 5 ? 1 : w->info.dim[w_nd - 5];
for (i = 0; i < w_nd - 5; i++)
w_batch_size *= w->info.dim[i];
int b_batch_size = b_nd < 4 ? 1 : b->info.dim[b_nd - 4];
b_batch_size = b_nd < 4 ? 1 : b->info.dim[b_nd - 4];
for (i = 0; i < b_nd - 4; i++)
b_batch_size *= b->info.dim[i];
assert(a_batch_size == b_batch_size || a_batch_size == 1);
@@ -256,12 +259,13 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 1,
.D_trans = 0,
.batched = is_batched,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = b_batch_size,
.batch_stride_a = a_batch_size > 1 ? H * W * I_dim : 0,
.batch_stride_b = w_batch_size > 1 ? O * I_dim : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * O : 0,
.batch_stride_d = 0,
};
} else {
params = (ccv_nnc_mfa_gemm_params_t){
@@ -272,56 +276,16 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 0,
.D_trans = 1,
.batched = is_batched,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = b_batch_size,
.batch_stride_a = w_batch_size > 1 ? O * I_dim : 0,
.batch_stride_b = a_batch_size > 1 ? H * W * I_dim : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * O : 0,
.batch_stride_d = 0,
};
}

if (is_batched) {
if (a->info.format == CCV_TENSOR_FORMAT_NHWC)
{
// Create a null-terminated list of batch dimensions.
int A_batch_dim = a_nd - 3;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_a[i] = adim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[A_batch_dim] = 0;
}

int B_batch_dim = w_nd - 4;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_b[i] = w->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[B_batch_dim] = 0;
}
} else {
// Create a null-terminated list of batch dimensions.
int B_batch_dim = w_nd - 4;
for (int i = 0; i < B_batch_dim; ++i) {
params.batch_dims_a[i] = w->info.dim[i];
}
if (B_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_a[B_batch_dim] = 0;
}

int A_batch_dim = a_nd - 3;
for (int i = 0; i < A_batch_dim; ++i) {
params.batch_dims_b[i] = adim[i];
}
if (A_batch_dim < CCV_NNC_MAX_DIM_ALLOC) {
params.batch_dims_b[A_batch_dim] = 0;
}
}

params.batch_dims_d[0] = 1;
params.batch_dims_d[1] = 0;
}
ccv_nnc_mfa_prepare_gemm(context, params);

mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
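The convolution path above feeds the same strided-batch GEMM. Reading off the hunk: one batch element of the activation spans H * W * I_dim elements, the weight matrix spans O * I_dim, and one output image spans H * W * O, with any operand whose batch size is 1 broadcast via a zero stride. A small illustration with placeholder sizes (the names follow the diff; the concrete numbers are made up):

/* Illustration only: the batch strides used when the convolution is lowered to
   a batched GEMM, with made-up sizes. */
#include <stdio.h>

int main(void)
{
  const int H = 32, W = 32, I_dim = 64, O = 128;    /* spatial size, input / output channels */
  const int a_batch = 4, w_batch = 1, b_batch = 4;  /* weights shared across the batch */
  const long batch_stride_a = a_batch > 1 ? (long)H * W * I_dim : 0; /* activations */
  const long batch_stride_b = w_batch > 1 ? (long)O * I_dim : 0;     /* weights: 0 => broadcast */
  const long batch_stride_c = b_batch > 1 ? (long)H * W * O : 0;     /* outputs */
  printf("stride_a=%ld stride_b=%ld stride_c=%ld\n",
         batch_stride_a, batch_stride_b, batch_stride_c);
  return 0;
}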
@@ -316,12 +316,13 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.A_trans = false,
.B_trans = true,
.D_trans = false,
.batched = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
.batch_dims_d = { 0 },
.batch_dimension = 1,
.batch_stride_a = 0,
.batch_stride_b = 0,
.batch_stride_c = 0,
.batch_stride_d = 0,
};
ccv_nnc_mfa_prepare_gemm(context, params);
