diff --git a/executorch-sys/cpp/executorch_rs_ext/api_utils.cpp b/executorch-sys/cpp/executorch_rs_ext/api_utils.cpp index 982931e..335ed51 100644 --- a/executorch-sys/cpp/executorch_rs_ext/api_utils.cpp +++ b/executorch-sys/cpp/executorch_rs_ext/api_utils.cpp @@ -174,11 +174,11 @@ namespace executorch_rs } #if defined(EXECUTORCH_RS_MODULE) - torch::executor::Module Module_new(torch::executor::ArrayRef file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer) + torch::executor::Module *Module_new(torch::executor::ArrayRef file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer) { std::string file_path_str(file_path.begin(), file_path.end()); std::unique_ptr event_tracer2(event_tracer); - return torch::executor::Module(file_path_str, mlock_config, std::move(event_tracer2)); + return new torch::executor::Module(file_path_str, mlock_config, std::move(event_tracer2)); } void Module_destructor(torch::executor::Module *module_) { diff --git a/executorch-sys/cpp/executorch_rs_ext/api_utils.hpp b/executorch-sys/cpp/executorch_rs_ext/api_utils.hpp index 53f9242..c9fd266 100644 --- a/executorch-sys/cpp/executorch_rs_ext/api_utils.hpp +++ b/executorch-sys/cpp/executorch_rs_ext/api_utils.hpp @@ -102,7 +102,7 @@ namespace executorch_rs torch::executor::util::BufferDataLoader BufferDataLoader_new(const void *data, size_t size); #if defined(EXECUTORCH_RS_MODULE) - torch::executor::Module Module_new(torch::executor::ArrayRef file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer); + torch::executor::Module *Module_new(torch::executor::ArrayRef file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer); void Module_destructor(torch::executor::Module *module); torch::executor::Result>> Module_method_names(torch::executor::Module *module); torch::executor::Error Module_load_method(torch::executor::Module *module, torch::executor::ArrayRef method_name); diff --git a/src/module.rs b/src/module.rs index 011e570..54de535 100644 --- a/src/module.rs +++ b/src/module.rs @@ -8,6 +8,7 @@ //! //! See the `hello_world_add` example for how to load and execute a module. +use core::ptr::NonNull; use std::collections::HashSet; use std::path::Path; use std::ptr; @@ -21,7 +22,7 @@ use crate::{et_c, et_rs_c}; /// A facade class for loading programs and executing methods within them. /// /// See the `hello_world_add` example for how to load and execute a module. -pub struct Module(et_c::Module); +pub struct Module(NonNull); impl Module { /// Constructs an instance by loading a program from a file with specified /// memory locking behavior. @@ -43,7 +44,8 @@ impl Module { let file_path = ArrayRef::from_slice(util::str2chars(file_path).unwrap()); let mlock_config = mlock_config.unwrap_or(MlockConfig::UseMlock); let event_tracer = ptr::null_mut(); // TODO: support event tracer - Self(unsafe { et_rs_c::Module_new(file_path.0, mlock_config, event_tracer) }) + let module = unsafe { et_rs_c::Module_new(file_path.0, mlock_config, event_tracer) }; + Self(unsafe { NonNull::new_unchecked(module) }) } /// Loads the program using the specified data loader and memory allocator. @@ -58,7 +60,7 @@ impl Module { /// An Error to indicate success or failure of the loading process. pub fn load(&mut self, verification: Option) -> Result<()> { let verification = verification.unwrap_or(ProgramVerification::Minimal); - unsafe { et_c::Module_load(&mut self.0, verification) }.rs() + unsafe { et_c::Module_load(self.0.as_mut(), verification) }.rs() } /// Checks if the program is loaded. @@ -67,7 +69,7 @@ impl Module { /// /// true if the program is loaded, false otherwise. pub fn is_loaded(&self) -> bool { - unsafe { et_c::Module_is_loaded(&self.0) } + unsafe { et_c::Module_is_loaded(self.0.as_ref()) } } /// Get a list of method names available in the loaded program. @@ -77,7 +79,7 @@ impl Module { /// /// A set of strings containing the names of the methods, or an error if the program or method failed to load. pub fn method_names(&mut self) -> Result> { - let names = unsafe { et_rs_c::Module_method_names(&mut self.0) } + let names = unsafe { et_rs_c::Module_method_names(self.0.as_mut()) } .rs()? .rs(); Ok(names @@ -103,7 +105,7 @@ impl Module { /// If the method name is not a valid UTF-8 string or contains a null character. pub fn load_method(&mut self, method_name: &str) -> Result<()> { let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap()); - unsafe { et_rs_c::Module_load_method(&mut self.0, method_name.0) }.rs() + unsafe { et_rs_c::Module_load_method(self.0.as_mut(), method_name.0) }.rs() } /// Checks if a specific method is loaded. @@ -121,7 +123,7 @@ impl Module { /// If the method name is not a valid UTF-8 string or contains a null character. pub fn is_method_loaded(&self, method_name: &str) -> bool { let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap()); - unsafe { et_rs_c::Module_is_method_loaded(&self.0, method_name.0) } + unsafe { et_rs_c::Module_is_method_loaded(self.0.as_ref(), method_name.0) } } /// Get a method metadata struct by method name. @@ -140,9 +142,7 @@ impl Module { /// If the method name is not a valid UTF-8 string or contains a null character. pub fn method_meta<'a>(&'a self, method_name: &str) -> Result> { let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap()); - let meta = - unsafe { et_rs_c::Module_method_meta(&self.0 as *const _ as *mut _, method_name.0) } - .rs()?; + let meta = unsafe { et_rs_c::Module_method_meta(self.0.as_ptr(), method_name.0) }.rs()?; Ok(unsafe { MethodMeta::new(meta) }) } @@ -171,7 +171,7 @@ impl Module { let inputs = unsafe { std::mem::transmute::<&[EValue], &[et_c::EValue]>(inputs) }; let inputs = ArrayRef::from_slice(inputs); let outputs = - unsafe { et_rs_c::Module_execute(&mut self.0, method_name.0, inputs.0) }.rs()?; + unsafe { et_rs_c::Module_execute(self.0.as_mut(), method_name.0, inputs.0) }.rs()?; // Safety: The transmute is safe because the memory layout of EValue and et_c::EValue is the same. let outputs = unsafe { std::mem::transmute::, et_rs_c::Vec>>(outputs) @@ -195,7 +195,7 @@ impl Module { } impl Drop for Module { fn drop(&mut self) { - unsafe { et_rs_c::Module_destructor(&mut self.0) }; + unsafe { et_rs_c::Module_destructor(self.0.as_mut()) }; } }