diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 629c92673..55e1134ec 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -196,7 +196,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0)); - int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F); + int is_downcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_16F); const int is_mfa_supported = ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM)); @@ -385,7 +385,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .B_trans = (is_transpose_w ? 1 : 0), .D_trans = 0, .fused_bias = (bias ? 1 : 0), - .register_float = (is_upcast ? 1 : 0), + .register_float = (is_downcast ? 0 : 1), .batch_dimension = b_batch_size, .batch_stride_a = a_batch_size > 1 ? ccv_max(a_batch_stride, b_rows * w_rows) : 0,