Skip to content

Commit

Permalink
Default to accumulator at FP32.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Aug 18, 2024
1 parent 76d2b6c commit 88ef7bc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F);
int is_downcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_16F);
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM));

Expand Down Expand Up @@ -385,7 +385,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
.fused_bias = (bias ? 1 : 0),
.register_float = (is_upcast ? 1 : 0),
.register_float = (is_downcast ? 0 : 1),

.batch_dimension = b_batch_size,
.batch_stride_a = a_batch_size > 1 ? ccv_max(a_batch_stride, b_rows * w_rows) : 0,
Expand Down

0 comments on commit 88ef7bc

Please sign in to comment.