Skip to content

Commit

Permalink
Avoid bitwise moving of CPP Module
Browse files Browse the repository at this point in the history
  • Loading branch information
barakugav committed Aug 16, 2024
1 parent 4f772f1 commit 7a5e21c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions executorch-sys/cpp/executorch_rs_ext/api_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ namespace executorch_rs
}

#if defined(EXECUTORCH_RS_MODULE)
torch::executor::Module Module_new(torch::executor::ArrayRef<char> file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer)
torch::executor::Module *Module_new(torch::executor::ArrayRef<char> file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer)
{
std::string file_path_str(file_path.begin(), file_path.end());
std::unique_ptr<torch::executor::EventTracer> event_tracer2(event_tracer);
return torch::executor::Module(file_path_str, mlock_config, std::move(event_tracer2));
return new torch::executor::Module(file_path_str, mlock_config, std::move(event_tracer2));
}
void Module_destructor(torch::executor::Module *module_)
{
Expand Down
2 changes: 1 addition & 1 deletion executorch-sys/cpp/executorch_rs_ext/api_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ namespace executorch_rs
torch::executor::util::BufferDataLoader BufferDataLoader_new(const void *data, size_t size);

#if defined(EXECUTORCH_RS_MODULE)
torch::executor::Module Module_new(torch::executor::ArrayRef<char> file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer);
torch::executor::Module *Module_new(torch::executor::ArrayRef<char> file_path, torch::executor::Module::MlockConfig mlock_config, torch::executor::EventTracer *event_tracer);
void Module_destructor(torch::executor::Module *module);
torch::executor::Result<Vec<Vec<char>>> Module_method_names(torch::executor::Module *module);
torch::executor::Error Module_load_method(torch::executor::Module *module, torch::executor::ArrayRef<char> method_name);
Expand Down
24 changes: 12 additions & 12 deletions src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//!
//! See the `hello_world_add` example for how to load and execute a module.

use core::ptr::NonNull;
use std::collections::HashSet;
use std::path::Path;
use std::ptr;
Expand All @@ -21,7 +22,7 @@ use crate::{et_c, et_rs_c};
/// A facade class for loading programs and executing methods within them.
///
/// See the `hello_world_add` example for how to load and execute a module.
pub struct Module(et_c::Module);
pub struct Module(NonNull<et_c::Module>);
impl Module {
/// Constructs an instance by loading a program from a file with specified
/// memory locking behavior.
Expand All @@ -43,7 +44,8 @@ impl Module {
let file_path = ArrayRef::from_slice(util::str2chars(file_path).unwrap());
let mlock_config = mlock_config.unwrap_or(MlockConfig::UseMlock);
let event_tracer = ptr::null_mut(); // TODO: support event tracer
Self(unsafe { et_rs_c::Module_new(file_path.0, mlock_config, event_tracer) })
let module = unsafe { et_rs_c::Module_new(file_path.0, mlock_config, event_tracer) };
Self(unsafe { NonNull::new_unchecked(module) })
}

/// Loads the program using the specified data loader and memory allocator.
Expand All @@ -58,7 +60,7 @@ impl Module {
/// An Error to indicate success or failure of the loading process.
pub fn load(&mut self, verification: Option<ProgramVerification>) -> Result<()> {
let verification = verification.unwrap_or(ProgramVerification::Minimal);
unsafe { et_c::Module_load(&mut self.0, verification) }.rs()
unsafe { et_c::Module_load(self.0.as_mut(), verification) }.rs()
}

/// Checks if the program is loaded.
Expand All @@ -67,7 +69,7 @@ impl Module {
///
/// true if the program is loaded, false otherwise.
pub fn is_loaded(&self) -> bool {
unsafe { et_c::Module_is_loaded(&self.0) }
unsafe { et_c::Module_is_loaded(self.0.as_ref()) }
}

/// Get a list of method names available in the loaded program.
Expand All @@ -77,7 +79,7 @@ impl Module {
///
/// A set of strings containing the names of the methods, or an error if the program or method failed to load.
pub fn method_names(&mut self) -> Result<HashSet<String>> {
let names = unsafe { et_rs_c::Module_method_names(&mut self.0) }
let names = unsafe { et_rs_c::Module_method_names(self.0.as_mut()) }
.rs()?
.rs();
Ok(names
Expand All @@ -103,7 +105,7 @@ impl Module {
/// If the method name is not a valid UTF-8 string or contains a null character.
pub fn load_method(&mut self, method_name: &str) -> Result<()> {
let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap());
unsafe { et_rs_c::Module_load_method(&mut self.0, method_name.0) }.rs()
unsafe { et_rs_c::Module_load_method(self.0.as_mut(), method_name.0) }.rs()
}

/// Checks if a specific method is loaded.
Expand All @@ -121,7 +123,7 @@ impl Module {
/// If the method name is not a valid UTF-8 string or contains a null character.
pub fn is_method_loaded(&self, method_name: &str) -> bool {
let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap());
unsafe { et_rs_c::Module_is_method_loaded(&self.0, method_name.0) }
unsafe { et_rs_c::Module_is_method_loaded(self.0.as_ref(), method_name.0) }
}

/// Get a method metadata struct by method name.
Expand All @@ -140,9 +142,7 @@ impl Module {
/// If the method name is not a valid UTF-8 string or contains a null character.
pub fn method_meta<'a>(&'a self, method_name: &str) -> Result<MethodMeta<'a>> {
let method_name = ArrayRef::from_slice(util::str2chars(method_name).unwrap());
let meta =
unsafe { et_rs_c::Module_method_meta(&self.0 as *const _ as *mut _, method_name.0) }
.rs()?;
let meta = unsafe { et_rs_c::Module_method_meta(self.0.as_ptr(), method_name.0) }.rs()?;
Ok(unsafe { MethodMeta::new(meta) })
}

Expand Down Expand Up @@ -171,7 +171,7 @@ impl Module {
let inputs = unsafe { std::mem::transmute::<&[EValue], &[et_c::EValue]>(inputs) };
let inputs = ArrayRef::from_slice(inputs);
let outputs =
unsafe { et_rs_c::Module_execute(&mut self.0, method_name.0, inputs.0) }.rs()?;
unsafe { et_rs_c::Module_execute(self.0.as_mut(), method_name.0, inputs.0) }.rs()?;
// Safety: The transmute is safe because the memory layout of EValue and et_c::EValue is the same.
let outputs = unsafe {
std::mem::transmute::<et_rs_c::Vec<et_c::EValue>, et_rs_c::Vec<EValue<'a>>>(outputs)
Expand All @@ -195,7 +195,7 @@ impl Module {
}
impl Drop for Module {
fn drop(&mut self) {
unsafe { et_rs_c::Module_destructor(&mut self.0) };
unsafe { et_rs_c::Module_destructor(self.0.as_mut()) };
}
}

Expand Down

0 comments on commit 7a5e21c

Please sign in to comment.