Skip to content

Commit

Permalink
Commit the lib file and pass in the path.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jul 22, 2023
1 parent bc0727f commit 9d8749b
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 22 deletions.
Binary file added lib/nnc/mfa/3rdparty/libmfaios16-0.2.metallib
Binary file not shown.
Binary file added lib/nnc/mfa/3rdparty/libmfamacos13-0.2.metallib
Binary file not shown.
20 changes: 5 additions & 15 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using namespace ccv::nnc;

// MARK: - C

// C-callable factory: allocates an mfa::context with `new` and returns the raw
// pointer. Presumably released through ccv_nnc_deinit_mfa_context — confirm.
// NOTE(review): this span is a diff rendering without +/- markers. The
// one-argument signature/body pair is the pre-commit version; the two-argument
// pair that follows it is the replacement.
mfa::context* ccv_nnc_init_mfa_context(MTL::Device* device) {
return new mfa::context(device);
// Post-commit form: `metallib_path` is forwarded to the constructor unchanged,
// so the caller now decides where the metallib is loaded from.
mfa::context* ccv_nnc_init_mfa_context(MTL::Device* device, const char* metallib_path) {
return new mfa::context(device, metallib_path);
}

void ccv_nnc_deinit_mfa_context(mfa::context* context) {
Expand Down Expand Up @@ -82,7 +82,7 @@ void mfa::cache<mfa::gemm::hash, mfa::gemm::pipeline>::prepare(mfa::context* con
_mfa_cache_prepare(&map, context, hash, async);
}

mfa::context::context(MTL::Device* device)
mfa::context::context(MTL::Device* device, const char* metallib_path)
{
auto* pool = NS::AutoreleasePool::alloc()->init();

Expand All @@ -101,19 +101,9 @@ mfa::context::context(MTL::Device* device)
// Example: /usr/local/MetalFlashAttention/lib/libMetalFlashAttention.metallib
// We need to have two different variants based on the operating system. macOS
// will not accept a metallib compiled for iOS/tvOS/visionOS and vice versa.
const char* metallib_path = getenv("CCV_NNC_MFA_METALLIB_PATH");
if (!metallib_path) {
// If a metallib was bundled with the Bazel build, you can hard-code the
// metallib's path into the source code. Choose this path if the user hasn't
// already set the `CCV_NNC_MFA_METALLIB_PATH` environment variable.
constexpr const char* bundled_path = nullptr;

if (bundled_path) {
metallib_path = bundled_path;
} else {
this->supported = false;
return;
}
this->supported = false;
return;
}
if (METAL_LOG_LEVEL(this) >= 1) {
std::cerr << METAL_LOG_HEADER << "Started loading 'libMetalFlashAttention.metallib'." << std::endl;
Expand Down
4 changes: 2 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class context {
NS::SharedPtr<MTL::Device> device;
NS::SharedPtr<MTL::Library> library;

context(MTL::Device* device);
context(MTL::Device* device, const char* metallib_path);

// MFA keeps internal caches of pipeline state objects. If you're eagerly
// executing a command, call `sync_prepare_*` just before encoding it. This
Expand All @@ -57,7 +57,7 @@ class context {
extern "C" {
#endif // __cplusplus

ccv_nnc_mfa_context_t* ccv_nnc_init_mfa_context(mtl_device_t* context);
ccv_nnc_mfa_context_t* ccv_nnc_init_mfa_context(mtl_device_t* context, const char* metallib_path);
void ccv_nnc_deinit_mfa_context(ccv_nnc_mfa_context_t* context);
uint8_t ccv_nnc_mfa_context_supported(ccv_nnc_mfa_context_t* context);
uint16_t ccv_nnc_mfa_context_log_level(ccv_nnc_mfa_context_t* context);
Expand Down
23 changes: 18 additions & 5 deletions lib/nnc/mps/ccv_nnc_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string.h>
#import <CoreFoundation/CoreFoundation.h>
#import <Foundation/Foundation.h>
#import <TargetConditionals.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#import <objc/runtime.h>
Expand All @@ -24,12 +25,28 @@
return device;
}

// Declared ahead of ccv_nnc_default_mfa_context() because that function passes
// [MTLFileBackedBuffer class] to +[NSBundle bundleForClass:] to locate the
// bundle that holds the metallib resources.
// NOTE(review): `path` / `size` presumably describe the backing file's location
// and byte length — confirm against the code that populates them.
@interface MTLFileBackedBuffer: NSObject
@property (nonatomic, copy) NSString* path;
@property (nonatomic, assign) NSUInteger size;
@end

// Returns the shared MFA context. The context is created on the first call
// inside dispatch_once; every later call returns the same pointer.
// Metallib path resolution order:
//   1. the CCV_NNC_MFA_METALLIB_PATH environment variable, when set;
//   2. otherwise a metallib resource looked up in the bundle that contains
//      the MTLFileBackedBuffer class.
ccv_nnc_mfa_context_t* ccv_nnc_default_mfa_context(void)
{
static dispatch_once_t once;
static ccv_nnc_mfa_context_t* context;
dispatch_once(&once, ^{
// NOTE(review): this is a diff rendering without +/- markers. The one-argument
// call on the next line is the line the commit removes; the statements after
// it are its replacement.
context = ccv_nnc_init_mfa_context((__bridge mtl_device_t*)ccv_nnc_default_device());
// An explicit environment override takes precedence over the bundled library.
const char* metallib_path = getenv("CCV_NNC_MFA_METALLIB_PATH");
if (metallib_path)
context = ccv_nnc_init_mfa_context((__bridge mtl_device_t*)ccv_nnc_default_device(), metallib_path);
else {
NSBundle* bundle = [NSBundle bundleForClass:[MTLFileBackedBuffer class]];
// Pick the resource name by build target: the "ios16" library when
// TARGET_OS_IPHONE or TARGET_OS_MACCATALYST is set, the "macos13" one otherwise.
#if TARGET_OS_IPHONE || TARGET_OS_MACCATALYST
NSString* path = [bundle pathForResource:@"libmfaios16-0.2" ofType:@"metallib"];
#else
NSString* path = [bundle pathForResource:@"libmfamacos13-0.2" ofType:@"metallib"];
#endif
// If the resource is missing, `path` is nil and `path.UTF8String` evaluates to
// NULL, so the callee can receive a NULL path here. The mfa::context
// constructor in this commit checks `!metallib_path` and marks itself unsupported.
context = ccv_nnc_init_mfa_context((__bridge mtl_device_t*)ccv_nnc_default_device(), path.UTF8String);
}
});
return context;
}
Expand Down Expand Up @@ -206,10 +223,6 @@ void mpobjfree(int device, void* ptr)
return buffer;
}

@interface MTLFileBackedBuffer: NSObject
@property (nonatomic, copy) NSString* path;
@property (nonatomic, assign) NSUInteger size;
@end
// Empty implementation: no methods are defined here; the class only carries
// the properties declared on its @interface.
@implementation MTLFileBackedBuffer
@end

Expand Down

0 comments on commit 9d8749b

Please sign in to comment.