Added new GEMM MFA implementation that is optimized for M3 / M4 devices.
liuliu committed Aug 16, 2024
1 parent 6c30517 commit 5c5bc18
Showing 27 changed files with 3,362 additions and 349 deletions.
523 changes: 523 additions & 0 deletions bin/nnc/laplacian_test.cpp

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion bin/nnc/makefile
@@ -4,7 +4,7 @@ LDFLAGS := -L"../../lib" -lccv $(LDFLAGS)
CFLAGS := -O3 -Wall -I"../../lib" $(CFLAGS)
NVFLAGS := -O3 -I"../../lib" -lineinfo $(NVFLAGS)

-TARGETS = nnc-e2e-verify nnc-e2e-sym-verify nnc-sym cifar-10 imagenet coco imdb iwslt wmt csv imdb_lstm
+TARGETS = nnc-e2e-verify nnc-e2e-sym-verify nnc-sym cifar-10 imagenet coco imdb iwslt wmt csv imdb_lstm laplacian_test

FUZZ_TARGETS = csv_fuzz

@@ -37,6 +37,9 @@ libccv.a:
%.o: %.c
$(CC) $< -o $@ -c $(CFLAGS)

+laplacian_test.o: laplacian_test.cpp
+	$(CC) $< -o $@ -c $(CFLAGS) -std=c++17

.gitignore:
echo $(TARGETS) | tr ' ' '\n' > .gitignore

2 changes: 1 addition & 1 deletion lib/configure
@@ -4760,7 +4760,7 @@ if test "$mps_support" = yes; then
printf "%s\n" "yes" >&6; }
DEFINE_MACROS="$DEFINE_MACROS-D HAVE_MPS "
-MKLDFLAGS="$MKLDFLAGS-framework MetalPerformanceShaders -framework MetalPerformanceShadersGraph -framework Foundation -framework Metal -lc++ "
+MKLDFLAGS="$MKLDFLAGS-framework MetalPerformanceShaders -framework MetalPerformanceShadersGraph -framework Foundation -framework Metal -framework OpenCL -lc++ "
CUDA_SRCS=""
21 changes: 3 additions & 18 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -175,10 +175,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
-// v1 only supports the same precision of accumulator as the tensor.
-int is_different_accumulator_precision = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F) || ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F);
+int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F);
const int is_mfa_supported =
-ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !is_different_accumulator_precision));
+ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM));

size_t a_data_size = 0;
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
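A note on the hunk above: v1 MFA bailed out whenever the requested accumulator precision differed from the tensor precision, so FP16 GEMMs asking for FP32 accumulation fell back to MPS. The v2 kernels accept that combination, and the code now forwards it as `register_float`. A minimal, self-contained sketch of the new gating, with stand-in constants for `CCV_NNC_GEMM_32F` and `CCV_16F` (illustrative only, not the literal library code):

```cpp
#include <cstdint>

// Stand-ins for CCV_NNC_GEMM_32F and CCV_16F/CCV_32F; illustrative only.
enum : uint32_t { GEMM_FLAG_32F_ACCUM = 1u << 0 };
enum class DType { F16, F32 };

// Mirrors the diff: instead of rejecting a mismatched accumulator precision
// (the removed is_different_accumulator_precision check), ask the v2 kernel
// to keep per-thread accumulators in 32-bit registers for FP16 inputs.
inline bool wants_register_float(uint32_t gemm_flags, DType tensor_dtype) {
  return (gemm_flags & GEMM_FLAG_32F_ACCUM) != 0 && tensor_dtype == DType::F16;
}
```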
@@ -364,11 +363,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 1 : 0),
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),
+.register_float = (is_upcast ? 1 : 0),

.batch_dims_a = { 0 },
.batch_dims_b = { 0 },
@@ -795,10 +792,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_w ? 1 : 0),
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -834,10 +828,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = (is_transpose_w ? 0 : 1),
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -881,10 +872,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 1,
.B_trans = (is_transpose_a ? 1 : 0),
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
@@ -920,10 +908,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = (is_transpose_a ? 0 : 1),
.B_trans = 0,
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = 0,

.batch_dims_a = { 0 },
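The five descriptor hunks above, plus the convolution and attention hunks below, make the same edit to the GEMM descriptor: `.alpha`, `.beta`, and `.fused_activation_function` are dropped, and the forward path adds `.register_float`. Every call site in this diff passed the identity values (alpha = 1.0, beta = 0.0, no fused activation), so the v2 descriptor simply omits them. An illustrative sketch of the resulting descriptor shape; the field names come from these hunks, while the struct name and the batch-dimension bound are assumptions:

```cpp
#include <cstdint>

constexpr int kMaxBatchDims = 12; // stand-in for the library's real bound

// Sketch of the GEMM descriptor as these hunks leave it; size/dtype members
// the hunks do not touch are omitted.
struct gemm_params_sketch {
  uint8_t A_trans;
  uint8_t B_trans;
  uint8_t D_trans;
  uint8_t batched;
  uint8_t fused_bias;      // bias fusion survives the v2 transition
  uint8_t register_float;  // new: FP32 register accumulation for FP16 inputs
  // Removed by this commit (always 1.0, 0.0 and 0 at every call site):
  //   float alpha, beta;
  //   uint32_t fused_activation_function;
  uint32_t batch_dims_a[kMaxBatchDims];
  uint32_t batch_dims_b[kMaxBatchDims];
};
```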
6 changes: 0 additions & 6 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
@@ -256,10 +256,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 1,
.D_trans = 0,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
Expand All @@ -275,10 +272,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.A_trans = 0,
.B_trans = 0,
.D_trans = 1,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = is_batched,
-.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
@@ -316,10 +316,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.A_trans = false,
.B_trans = true,
.D_trans = false,
-.alpha = (float)1.0,
-.beta = (float)0.0,
.batched = 0,
-.fused_activation_function = 0,
.fused_bias = (bias ? 1 : 0),

.batch_dims_a = { 0 },
10 changes: 4 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -10,6 +10,10 @@ mfa::context* ccv_nnc_init_mfa_context(MTL::Device* device) {
return new mfa::context(device);
}

+void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context) {
+context->v2_cache.evict();
+}

void ccv_nnc_deinit_mfa_context(mfa::context* context) {
delete context;
}
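A hedged usage sketch for the new entry point. `ccv_nnc_default_mfa_context()` and `ccv_nnc_mfa_context_supported()` are the accessors that appear elsewhere in this diff; the model-switch scenario and the lazy-recompile behavior are assumptions, not something this commit states:

```cpp
#include "ccv_nnc_mfa.hpp" // assumed to expose the C declarations used below

// Drop every compiled v2 pipeline, e.g. when switching models or under
// memory pressure. Pipelines are presumably recompiled on the next dispatch.
static void drop_compiled_pipelines(void) {
  ccv_nnc_mfa_context_t* const context = ccv_nnc_default_mfa_context();
  if (ccv_nnc_mfa_context_supported(context))
    ccv_nnc_mfa_clear_pipeline_cache(context); // forwards to v2_cache.evict()
}
```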
@@ -86,12 +90,6 @@ void mfa::cache<mfa::attention::hash, mfa::attention::pipeline>::prepare(mfa::co
_mfa_cache_prepare(&map, context, hash);
}

-template <>
-void mfa::cache<mfa::gemm::hash, mfa::gemm::pipeline>::prepare(mfa::context* context, mfa::gemm::hash hash)
-{
-_mfa_cache_prepare(&map, context, hash);
-}

template <>
void mfa::cache<mfa::normalization::hash, mfa::normalization::pipeline>::prepare(mfa::context* context, mfa::normalization::hash hash)
{
7 changes: 5 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
@@ -4,11 +4,11 @@
#include "nnc/ccv_nnc.h"
#include "ccv_nnc_mfa_defines.hpp"
#include "ccv_nnc_mfa_attention.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_normalization.hpp"
#include "ccv_nnc_mfa_depalettize.hpp"
#include "ccv_nnc_mfa_adam.hpp"
#include "ccv_nnc_mfa_cmul.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_gemv.hpp"
#include "ccv_nnc_mfa_cast.hpp"
#include "ccv_nnc_mfa_add.hpp"
Expand All @@ -17,6 +17,7 @@
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include "ccv_nnc_mfa_error.hpp"
#include "v2/ShaderCache.hpp"
#include <unordered_map>

namespace ccv {
@@ -48,14 +49,15 @@ class context {
context(MTL::Device* device);

cache<attention::hash, attention::pipeline> attention_cache;
-cache<gemm::hash, gemm::pipeline> gemm_cache;
cache<normalization::hash, normalization::pipeline> normalization_cache;
cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
cache<adam::hash, adam::pipeline> adam_cache;
cache<cmul::hash, cmul::pipeline> cmul_cache;
cache<gemv::hash, gemv::pipeline> gemv_cache;
cache<cast::hash, cast::pipeline> cast_cache;
cache<add::hash, add::pipeline> add_cache;

+ShaderCache v2_cache;

MTL::Buffer* request_scratch(uint64_t size);
};
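With `gemm_cache` removed from the context, compiled GEMM pipelines now live in the v2 `ShaderCache` that `ccv_nnc_mfa_clear_pipeline_cache()` evicts. The only `ShaderCache` member visible in this diff is `evict()`; the sketch below shows an assumed shape for such a descriptor-keyed pipeline cache, not the real interface:

```cpp
#include <memory>
#include <mutex>
#include <unordered_map>

// Assumed shape only: a descriptor-keyed cache of compiled pipelines with the
// evict() entry point this commit wires up through
// ccv_nnc_mfa_clear_pipeline_cache().
template <typename Key, typename Pipeline, typename Hash = std::hash<Key>>
class PipelineCacheSketch {
public:
  std::shared_ptr<Pipeline> find(const Key& key) {
    std::lock_guard<std::mutex> guard(lock_);
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : it->second;
  }
  void insert(const Key& key, std::shared_ptr<Pipeline> pipeline) {
    std::lock_guard<std::mutex> guard(lock_);
    map_.emplace(key, std::move(pipeline));
  }
  void evict() { // drop every compiled pipeline; rebuilt on demand
    std::lock_guard<std::mutex> guard(lock_);
    map_.clear();
  }
private:
  std::mutex lock_;
  std::unordered_map<Key, std::shared_ptr<Pipeline>, Hash> map_;
};
```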
Expand All @@ -68,6 +70,7 @@ extern "C" {
#endif // __cplusplus

ccv_nnc_mfa_context_t* ccv_nnc_init_mfa_context(mtl_device_t* context);
+void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context);
void ccv_nnc_deinit_mfa_context(ccv_nnc_mfa_context_t* context);
uint8_t ccv_nnc_mfa_context_supported(ccv_nnc_mfa_context_t* context);
uint16_t ccv_nnc_mfa_context_log_level(ccv_nnc_mfa_context_t* context);
(Diffs for the remaining 19 changed files are not rendered.)