Pull in MFA GEMM (#244)
* Remove .DS_Store

* Initial code for loading MFA

* Fix error messages

* Attempt to compile C++ code

* Remove unused dependency

* [PATCH] Fix cpp compilation errors.

* Fix error output

* Refactor the types

* Dispatch C++ bindings

* Add hasher

* Refactor MFA errors

* Skeleton implementation of 'encode_gemm'

* Calculate threadgroup memory and grid XY

* Finish draft of encoding code

* MTL::CommandBatch

* Detect compatible MFA GEMMs

* Add bias restriction

* Fix up MTL::CommandBatch

* Fix typo

* 48x48 is the most appropriate default for FP32

* It is running

* Introduce max stream count when scheduling a graph.

* Avoid stack allocation if it is too large.

* Don't allocate gradients if it doesn't compute.

* No gradients are allocated if these are not trainable.

* Fix up the map

* Document the block sizes

* Block size selection heuristic

* Preparation for batching

* Add code for encoding the matrix offsets in a batch

* Support a subset of batching

* Caught a bug in the dispatching code

* Move files around to make compilation happy.

* Regenerate configure file.

* Add unordered map import for ccv_nnc_mfa.hpp

* Style updates for ccv_nnc_gemm_mps.

* Revert an update to lib/ccv.h

* Move 2-space to tab.

* Minor change to force init mps.

* Commit the lib file and pass in the path.

* Gate with OSX, MAC is 1 on both iOS and Mac.

* Call useResource on the buffers.

* Update lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp

Co-authored-by: Philip Turner <[email protected]>

---------

Co-authored-by: Philip Turner <[email protected]>
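
The bullets on grid XY and the 48x48 FP32 default can be illustrated with a small sketch. This is not the code from this commit. It assumes each threadgroup computes one output tile, so each grid dimension is the matrix dimension divided by the block size, rounded up. The names `ceil_div`, `GridXY`, and `gemm_grid` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: integer division rounded up.
static uint64_t ceil_div(uint64_t value, uint64_t block) {
  assert(block > 0);
  return (value + block - 1) / block;
}

// Hypothetical result type: number of threadgroups along X and Y.
struct GridXY {
  uint64_t x;
  uint64_t y;
};

// Assumption: one threadgroup per block_m x block_n tile of the M x N output,
// with columns mapped to grid X and rows mapped to grid Y.
static GridXY gemm_grid(uint64_t M, uint64_t N, uint64_t block_m, uint64_t block_n) {
  return GridXY{ceil_div(N, block_n), ceil_div(M, block_m)};
}
```

Under these assumptions, a 100x200 output with 48x48 blocks needs a 5x3 grid; edge tiles are only partially filled, which is why the division rounds up.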
liuliu and philipturner authored Jul 24, 2023
1 parent e5a13c0 commit bbd3810
Showing 51 changed files with 3,201 additions and 1,239 deletions.
27 changes: 26 additions & 1 deletion lib/BUILD.bazel
@@ -30,6 +30,17 @@ cc_library(
copts = ccv_default_copts()
)

+cc_library(
+name = "metal_cpp_hdrs",
+srcs = [
+"nnc/mfa/3rdparty/metal-cpp/Dispatch.cpp",
+],
+hdrs = [
+"nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp",
+"nnc/mfa/3rdparty/metal-cpp/Metal.hpp",
+],
+)
+
cc_library(
name = "siphash",
srcs = [
@@ -359,6 +370,17 @@ cuda_library(
]
)

+cc_library(
+name = "nnc_mfa_compat",
+srcs = glob(["nnc/mfa/**/*.cpp"]),
+hdrs = glob(["nnc/mfa/**/*.hpp"]),
+copts = ccv_default_copts(),
+deps = [
+":metal_cpp_hdrs",
+":nnc_headers"
+]
+)
+
objc_library(
name = "nnc_mps_compat",
non_arc_srcs = [
@@ -370,8 +392,9 @@ objc_library(
copts = ccv_default_copts(),
sdk_frameworks = ["Metal", "MetalPerformanceShaders", "MetalPerformanceShadersGraph"],
deps = [
+":nnc_mfa_compat",
":nnc_headers",
-":SFMT_hdrs",
+":SFMT_hdrs"
]
)

@@ -417,6 +440,7 @@ objc_library(
copts = ccv_default_copts(),
deps = [
":nnc_headers",
+":nnc_mfa_compat",
":nnc_mps_compat",
]
)
@@ -516,6 +540,7 @@ cc_library(
"//conditions:default": []
}) + select({
"//config:have_mps": [
+":nnc_mfa_compat",
":nnc_mps_compat",
":cmd_mps"
],
1 change: 1 addition & 0 deletions lib/config.mk.in
@@ -4,6 +4,7 @@ NVCC := @NVCC@
CUDA_SRCS := @CUDA_SRCS@
CUDA_COMPAT_LIB := @CUDA_COMPAT_LIB@
CUDA_CMD_LIB := @CUDA_CMD_LIB@
+MFA_COMPAT_LIB := @MFA_COMPAT_LIB@
MPS_COMPAT_LIB := @MPS_COMPAT_LIB@
MPS_CMD_LIB := @MPS_CMD_LIB@
DEFINE_MACROS := @DEFINE_MACROS@