diff --git a/src/closure.rs b/src/closure.rs index 644eaf1..e3caaea 100644 --- a/src/closure.rs +++ b/src/closure.rs @@ -1,9 +1,8 @@ // Copyright (c) Viable Systems and TezEdge Contributors // SPDX-License-Identifier: MIT -use crate::error::OCamlException; use crate::mlvalues::tag; -use crate::mlvalues::{extract_exception, is_exception_result, tag_val, RawOCaml}; +use crate::mlvalues::{tag_val, RawOCaml}; use crate::value::OCaml; use crate::{OCamlRef, OCamlRuntime}; use ocaml_sys::{ @@ -70,11 +69,9 @@ impl OCamlClosure { cr: &'a mut OCamlRuntime, result: RawOCaml, ) -> OCaml<'a, R> { - if is_exception_result(result) { - let ex = unsafe { OCamlException::of(extract_exception(result)) }; - panic!("OCaml exception, message: {:?}", ex.message()) - } else { - unsafe { OCaml::new(cr, result) } + match unsafe { OCaml::of_exception_result(cr, result) } { + Some(ex) => panic!("OCaml exception, message: {:?}", ex.message()), + None => unsafe { OCaml::new(cr, result) }, } } } diff --git a/src/error.rs b/src/error.rs index 1af1cf4..d191659 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,16 +1,8 @@ // Copyright (c) Viable Systems and TezEdge Contributors // SPDX-License-Identifier: MIT -use crate::mlvalues::{is_block, string_val, tag_val, RawOCaml}; -use crate::mlvalues::{tag, MAX_FIXNUM, MIN_FIXNUM}; -use core::{fmt, slice}; -use ocaml_sys::caml_string_length; - -/// An OCaml exception value. -#[derive(Debug)] -pub struct OCamlException { - raw: RawOCaml, -} +use crate::mlvalues::{MAX_FIXNUM, MIN_FIXNUM}; +use core::fmt; #[derive(Debug)] pub enum OCamlFixnumConversionError { @@ -34,30 +26,3 @@ impl fmt::Display for OCamlFixnumConversionError { } } } - -impl OCamlException { - #[doc(hidden)] - pub unsafe fn of(raw: RawOCaml) -> Self { - OCamlException { raw } - } - - pub fn message(&self) -> Option { - if is_block(self.raw) { - unsafe { - let message = *(self.raw as *const RawOCaml).add(1); - - if is_block(message) && tag_val(message) == tag::STRING { - let error_message = - slice::from_raw_parts(string_val(message), caml_string_length(message)) - .to_owned(); - let error_message = String::from_utf8_unchecked(error_message); - Some(error_message) - } else { - None - } - } - } else { - None - } - } -} diff --git a/src/lib.rs b/src/lib.rs index de6da19..b6ccdd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -296,12 +296,12 @@ pub use crate::boxroot::BoxRoot; pub use crate::closure::{OCamlFn1, OCamlFn2, OCamlFn3, OCamlFn4, OCamlFn5}; pub use crate::conv::{FromOCaml, ToOCaml}; -pub use crate::error::OCamlException; pub use crate::memory::alloc_cons as cons; pub use crate::memory::OCamlRef; +pub use crate::memory::{alloc_error, alloc_ok}; pub use crate::mlvalues::{ - bigarray, DynBox, OCamlBytes, OCamlFloat, OCamlFloatArray, OCamlInt, OCamlInt32, OCamlInt64, - OCamlList, OCamlUniformArray, RawOCaml, + bigarray, DynBox, OCamlBytes, OCamlException, OCamlFloat, OCamlFloatArray, OCamlInt, + OCamlInt32, OCamlInt64, OCamlList, OCamlUniformArray, RawOCaml, }; pub use crate::runtime::OCamlRuntime; pub use crate::value::OCaml; diff --git a/src/mlvalues.rs b/src/mlvalues.rs index 160000e..eeb6b78 100644 --- a/src/mlvalues.rs +++ b/src/mlvalues.rs @@ -62,3 +62,6 @@ pub struct OCamlInt64 {} /// [`OCaml`]`` is a reference to an OCaml `float` (boxed `float`) value. pub struct OCamlFloat {} + +/// [`OCaml`]`` is a reference to an OCaml `exn` value. +pub struct OCamlException {} diff --git a/src/mlvalues/bigarray.rs b/src/mlvalues/bigarray.rs index 4ac1a89..629b46a 100644 --- a/src/mlvalues/bigarray.rs +++ b/src/mlvalues/bigarray.rs @@ -5,7 +5,7 @@ use core::marker::PhantomData; /// # Safety /// /// This is unsafe to implement, because it allows casts -/// to the implementing type (through OCaml>::as_slice()). +/// to the implementing type (through `OCaml>::as_slice()`). /// /// To make this safe, the type implementing this trait must be /// safe to transmute from OCaml data with the relevant KIND. diff --git a/src/mlvalues/tag.rs b/src/mlvalues/tag.rs index c27db15..551dbba 100644 --- a/src/mlvalues/tag.rs +++ b/src/mlvalues/tag.rs @@ -6,5 +6,10 @@ pub use ocaml_sys::{ }; pub const TAG_POLYMORPHIC_VARIANT: Tag = 0; + +/// Note that `TAG_EXCEPTION`` is equivalent to `TAG_POLYMORPHIC_VARIANT`, and also +/// corresponds to the tag associated with records and tuples. +pub const TAG_EXCEPTION: Tag = 0; + pub const TAG_OK: Tag = 0; pub const TAG_ERROR: Tag = 1; diff --git a/src/value.rs b/src/value.rs index 24ba605..4ff7d61 100644 --- a/src/value.rs +++ b/src/value.rs @@ -6,7 +6,7 @@ use crate::{ error::OCamlFixnumConversionError, memory::{alloc_box, OCamlCell}, mlvalues::*, - FromOCaml, OCamlException, OCamlRef, OCamlRuntime, + FromOCaml, OCamlRef, OCamlRuntime, }; use core::any::Any; use core::borrow::Borrow; @@ -507,6 +507,37 @@ impl<'a, A: bigarray::BigarrayElt> OCaml<'a, bigarray::Array1> { } } +impl<'a> OCaml<'a, OCamlException> { + #[doc(hidden)] + pub unsafe fn of_exception_result( + cr: &'a OCamlRuntime, + exception_result: RawOCaml, + ) -> Option> { + if is_exception_result(exception_result) { + Some(OCaml::new(cr, extract_exception(exception_result))) + } else { + None + } + } + + /// If the exception has a single argument of type string, extracts and + /// returns it. Examples of such exceptions are `Failure of string` + /// (raised via the `failwith` OCaml function, or the + /// `caml_raise_with_string` C function) or `Invalid_argument of string`. + pub fn message(&self) -> Option { + if self.is_block_sized(2) && self.tag_value() == tag::TAG_EXCEPTION { + let exn_argument: OCaml = unsafe { self.field(1) }; + if exn_argument.is_block() && exn_argument.tag_value() == tag::STRING { + Some(exn_argument.to_rust()) + } else { + None + } + } else { + None + } + } +} + // Functions pub enum RefOrRooted<'a, 'b, T: 'static> { @@ -557,18 +588,16 @@ macro_rules! try_call_impl { &self, cr: &'c mut OCamlRuntime, $($argname: $argname),+ - ) -> Result, OCamlException> + ) -> Result, OCaml<'c, OCamlException>> where $($argname: OCamlParam<'a, 'b, $rt, $ot>),+ { $(let $argname = $argname.to_rooted(cr);)* let result = unsafe { $method(self.get_raw(), $($argname.get_raw()),+) }; - if is_exception_result(result) { - let ex = unsafe { OCamlException::of(extract_exception(result)) }; - Err(ex) - } else { - Ok(unsafe { OCaml::new(cr, result) }) + match unsafe { OCaml::of_exception_result(cr, result) } { + Some(ex) => Err(ex), + None => Ok(unsafe { OCaml::new(cr, result) }) } } } @@ -582,7 +611,7 @@ macro_rules! try_call_impl { &self, cr: &'c mut OCamlRuntime, $($argname2: $argname2),* - ) -> Result, OCamlException> + ) -> Result, OCaml<'c, OCamlException>> where $($argname2: OCamlParam<'a, 'b, $rt2, $ot2>),* { @@ -593,11 +622,9 @@ macro_rules! try_call_impl { }; let result = unsafe { caml_callbackN_exn(self.get_raw(), args.len(), args.as_mut_ptr()) }; - if is_exception_result(result) { - let ex = unsafe { OCamlException::of(extract_exception(result)) }; - Err(ex) - } else { - Ok(unsafe { OCaml::new(cr, result) }) + match unsafe { OCaml::of_exception_result(cr, result) } { + Some(ex) => Err(ex), + None => Ok(unsafe { OCaml::new(cr, result) }) } } } diff --git a/testing/ocaml-caller/ocaml_rust_caller.ml b/testing/ocaml-caller/ocaml_rust_caller.ml index 15a4ee2..c92f006 100644 --- a/testing/ocaml-caller/ocaml_rust_caller.ml +++ b/testing/ocaml-caller/ocaml_rust_caller.ml @@ -60,6 +60,12 @@ module Rust = struct external string_of_polymorphic_movement : movement_polymorphic -> string = "rust_string_of_polymorphic_movement" + external call_ocaml_closure : (int -> int) -> (int, string) result + = "rust_call_ocaml_closure" + + external call_ocaml_closure_and_return_exn : (int -> int) -> (int, exn) result + = "rust_call_ocaml_closure_and_return_exn" + external rust_rust_add_7ints : int -> int -> int -> int -> int -> int -> int -> int = "rust_rust_add_7ints_byte" "rust_rust_add_7ints" @@ -178,6 +184,33 @@ let test_interpret_polymorphic_movement () = Alcotest.(check (list string)) "Interpret a polymorphic variant" expected result +let test_call_ocaml_closure () = + let expected = [ Ok 1; Error "some error message"; Error "no message" ] in + let result = + [ + Rust.call_ocaml_closure (fun x -> x + 1); + Rust.call_ocaml_closure (fun _ -> failwith "some error message"); + Rust.call_ocaml_closure (fun _ -> raise Not_found); + ] + in + Alcotest.(check (list (result int string))) "Call a closure" expected result + +let test_call_ocaml_closure_and_return_exn () = + let expected = + [ Ok 1; Error (Failure "some error message"); Error Not_found ] + in + let result = + [ + Rust.call_ocaml_closure_and_return_exn (fun x -> x + 1); + Rust.call_ocaml_closure_and_return_exn (fun _ -> + failwith "some error message"); + Rust.call_ocaml_closure_and_return_exn (fun _ -> raise Not_found); + ] + in + let exn = Alcotest.of_pp Base.Exn.pp in + Alcotest.(check (list (result int exn))) + "Call a closure and return exn" expected result + let test_byte_function () = let expected = 1 + 2 + 3 + 4 + 5 + 6 + 7 in let result = Rust.rust_rust_add_7ints 1 2 3 4 5 6 7 in @@ -238,6 +271,9 @@ let () = test_case "Rust.string_of_movement" `Quick test_interpret_movement; test_case "Rust.string_of_polymorphic_movement" `Quick test_interpret_polymorphic_movement; + test_case "Rust.call_ocaml_closure" `Quick test_call_ocaml_closure; + test_case "Rust.call_ocaml_closure_and_return_exn" `Quick + test_call_ocaml_closure_and_return_exn; test_case "Rust.rust_rust_add_7ints" `Quick test_byte_function; ] ); ]; diff --git a/testing/ocaml-caller/rust/src/lib.rs b/testing/ocaml-caller/rust/src/lib.rs index 94dddde..322d2e0 100644 --- a/testing/ocaml-caller/rust/src/lib.rs +++ b/testing/ocaml-caller/rust/src/lib.rs @@ -2,9 +2,9 @@ // SPDX-License-Identifier: MIT use ocaml_interop::{ - ocaml_export, ocaml_unpack_polymorphic_variant, ocaml_unpack_variant, OCaml, OCamlBytes, - OCamlFloat, OCamlFloatArray, OCamlInt, OCamlInt32, OCamlInt64, OCamlList, OCamlRef, - OCamlUniformArray, ToOCaml, + alloc_error, alloc_ok, ocaml_export, ocaml_unpack_polymorphic_variant, ocaml_unpack_variant, + OCaml, OCamlBytes, OCamlException, OCamlFloat, OCamlFloatArray, OCamlInt, OCamlInt32, + OCamlInt64, OCamlList, OCamlRef, OCamlUniformArray, ToOCaml, }; use std::{thread, time}; @@ -177,6 +177,36 @@ ocaml_export! { s.to_ocaml(cr) } + fn rust_call_ocaml_closure(cr, ocaml_function: OCamlRef OCamlInt>) -> OCaml> { + let ocaml_function = ocaml_function.to_boxroot(cr); + + let call_result: Result = + ocaml_function + .try_call(cr, &0i64) + .map(|call_result| call_result.to_rust()) + .map_err(|exception| exception.message().unwrap_or("no message".to_string())); + call_result.to_ocaml(cr) + } + + fn rust_call_ocaml_closure_and_return_exn(cr, ocaml_function: OCamlRef OCamlInt>) -> OCaml> { + let ocaml_function = ocaml_function.to_boxroot(cr); + + let call_result: Result, OCaml> = + ocaml_function + .try_call(cr, &0i64); + + match call_result { + Ok(value) => { + let ocaml_value = value.root(); + alloc_ok(cr, &ocaml_value) + }, + Err(error) => { + let ocaml_error = error.root(); + alloc_error(cr, &ocaml_error) + } + } + } + fn rust_rust_add_7ints|rust_rust_add_7ints_byte( cr, int1: OCamlRef,