Skip to content

Commit

Permalink
Return an immutable tensor in as_typed and as_type_erased
Browse files Browse the repository at this point in the history
  • Loading branch information
barakugav committed Aug 28, 2024
1 parent f9387a8 commit 4d3e3d3
Showing 1 changed file with 55 additions and 14 deletions.
69 changes: 55 additions & 14 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,26 @@ pub struct TensorBase<'a, D: Data>(
)>,
);
impl<'a, D: Data> TensorBase<'a, D> {
/// Create a new tensor in a boxed heap memory.
///
/// # Safety
///
/// The caller must obtain a mutable reference to `tensor_impl` if the tensor is mutable.
#[cfg(feature = "alloc")]
fn new_boxed(tensor_impl: &'a TensorImplBase<D>) -> Self {
unsafe fn new_boxed(tensor_impl: &'a TensorImplBase<D>) -> Self {
let impl_ = &tensor_impl.0 as *const et_c::TensorImpl;
let impl_ = impl_.cast_mut();
// Safety: the closure init the pointer
let tensor = unsafe { NonTriviallyMovable::new_boxed(|p| et_rs_c::Tensor_new(p, impl_)) };
Self(tensor, PhantomData)
}

fn new_in_storage(
/// Create a new tensor in the given storage.
///
/// # Safety
///
/// The caller must ensure that the new tensor is compatible with the given storage.
unsafe fn new_in_storage(
tensor_impl: &'a TensorImplBase<D>,
storage: Pin<&'a mut Storage<TensorBase<D>>>,
) -> Self {
Expand All @@ -217,6 +227,7 @@ impl<'a, D: Data> TensorBase<'a, D> {
/// # Safety
///
/// The caller must ensure that the new data generic is compatible with the data of the given tensor.
/// If `D2` is a mutable data type, the caller must obtains a mutable reference to the given tensor.
pub(crate) unsafe fn convert_from_ref<D2: Data>(tensor: &'a TensorBase<D2>) -> Self {
Self(
NonTriviallyMovable::from_ref(tensor.as_cpp_tensor()),
Expand Down Expand Up @@ -313,9 +324,9 @@ impl<'a, D: Data> TensorBase<'a, D> {
}

/// Get a type erased tensor referencing the same internal data as the given tensor.
pub fn as_type_erased(&self) -> TensorBase<D::TypeErased> {
// Safety: D::TypeErased is compatible with D
unsafe { TensorBase::<D::TypeErased>::convert_from_ref(self) }
pub fn as_type_erased(&self) -> TensorBase<<D::Immutable as Data>::TypeErased> {
// Safety: <D::Immutable as Data>::TypeErased is compatible with D and its immutable (we took &self)
unsafe { TensorBase::<<D::Immutable as Data>::TypeErased>::convert_from_ref(self) }
}

/// Try to convert this tensor into a typed tensor with scalar type `S`.
Expand All @@ -338,12 +349,13 @@ impl<'a, D: Data> TensorBase<'a, D> {
}

/// Try to get a typed tensor with scalar type `S` referencing the same internal data as the given tensor.
pub fn try_as_typed<S: Scalar>(&self) -> Result<TensorBase<D::Typed<S>>> {
pub fn try_as_typed<S: Scalar>(&self) -> Result<TensorBase<<D::Immutable as Data>::Typed<S>>> {
if self.scalar_type() != Some(S::TYPE) {
return Err(Error::InvalidType);
}
// Safety: the scalar type is checked, D::Typed is compatible with D
Ok(unsafe { TensorBase::<D::Typed<S>>::convert_from_ref(self) })
// Safety: the scalar type is checked, <D::Immutable as Data>::Typed<S> is compatible with D and its
// immutable (we took &self)
Ok(unsafe { TensorBase::<<D::Immutable as Data>::Typed<S>>::convert_from_ref(self) })
}

/// Get a typed tensor with scalar type `S` referencing the same internal data as the given tensor.
Expand All @@ -352,7 +364,7 @@ impl<'a, D: Data> TensorBase<'a, D> {
///
/// If the scalar type of the tensor does not match the type `S`.
#[track_caller]
pub fn as_typed<S: Scalar>(&self) -> TensorBase<D::Typed<S>> {
pub fn as_typed<S: Scalar>(&self) -> TensorBase<<D::Immutable as Data>::Typed<S>> {
self.try_as_typed().expect("Invalid type")
}
}
Expand Down Expand Up @@ -591,7 +603,8 @@ impl<'a, S: Scalar> Tensor<'a, S> {
/// See `executorch::util::Storage` for more information.
#[cfg(feature = "alloc")]
pub fn new(tensor_impl: &'a TensorImpl<S>) -> Self {
Self::new_boxed(tensor_impl)
// Safety: both Self and TensorImpl are immutable
unsafe { Self::new_boxed(tensor_impl) }
}
}
impl<S: Scalar> Storage<Tensor<'_, S>> {
Expand All @@ -601,7 +614,8 @@ impl<S: Scalar> Storage<Tensor<'_, S>> {
/// See `executorch::util::Storage` for more information.
#[allow(clippy::new_ret_no_self)]
pub fn new<'a>(self: Pin<&'a mut Self>, tensor_impl: &'a TensorImpl<S>) -> Tensor<'a, S> {
Tensor::new_in_storage(tensor_impl, self)
// Safety: both Self and TensorImpl are immutable
unsafe { Tensor::new_in_storage(tensor_impl, self) }
}
}

Expand All @@ -620,7 +634,8 @@ impl<'a, S: Scalar> TensorMut<'a, S> {
/// See `executorch::util::Storage` for more information.
#[cfg(feature = "alloc")]
pub fn new(tensor_impl: &'a mut TensorImplMut<S>) -> Self {
Self::new_boxed(tensor_impl)
// Safety: Self has a mutable data, and we indeed took a mutable reference to tensor_impl
unsafe { Self::new_boxed(tensor_impl) }
}
}
impl<S: Scalar> Storage<TensorMut<'_, S>> {
Expand All @@ -629,8 +644,12 @@ impl<S: Scalar> Storage<TensorMut<'_, S>> {
/// This function is identical to `TensorMut::new`, but it allows to create the tensor on the stack.
/// See `executorch::util::Storage` for more information.
#[allow(clippy::new_ret_no_self)]
pub fn new<'a>(self: Pin<&'a mut Self>, tensor_impl: &'a TensorImplMut<S>) -> TensorMut<'a, S> {
TensorMut::new_in_storage(tensor_impl, self)
pub fn new<'a>(
self: Pin<&'a mut Self>,
tensor_impl: &'a mut TensorImplMut<S>,
) -> TensorMut<'a, S> {
// Safety: Self has a mutable data, and we indeed took a mutable reference to tensor_impl
unsafe { TensorMut::new_in_storage(tensor_impl, self) }
}
}

Expand Down Expand Up @@ -751,16 +770,30 @@ impl<'a, S: Scalar> TensorImplMut<'a, S> {

/// A marker trait that provide information about the data type of a [`TensorBase`] and [`TensorImplBase`]
pub trait Data {
/// An immutable version of the data type.
///
/// For example, if the data type is `ViewMut<f32>`, the immutable version is `View<f32>`.
/// If the data is already immutable, the immutable version is the same as the data type.
type Immutable: Data;

/// A mutable version of the data type.
///
/// For example, if the data type is `View<f32>`, the mutable version is `ViewMut<f32>`.
/// If the data is already mutable, the mutable version is the same as the data type.
type Mutable: DataMut;

/// A type-erased version of the data type.
///
/// For example, if the data type is `View<f32>`, the type-erased version is `ViewAny`.
/// If the data is already type-erased, the type-erased version is the same as the data type.
type TypeErased: Data;

/// A typed version of the data type.
///
/// For example, if the data type is `ViewAny`, the typed version is `View<S>`.
/// If the data is already typed, the typed version is the same as the data type.
type Typed<S: Scalar>: DataTyped<Scalar = S>;

private_decl! {}
}
/// A marker trait extending [`Data`] that indicate that the data is mutable.
Expand All @@ -775,6 +808,8 @@ pub trait DataTyped: Data {
/// A marker type of typed immutable data of a tensor.
pub struct View<S: Scalar>(PhantomData<S>);
impl<S: Scalar> Data for View<S> {
type Immutable = View<S>;
type Mutable = ViewMut<S>;
type TypeErased = ViewAny;
type Typed<S2: Scalar> = View<S2>;
private_impl! {}
Expand All @@ -785,6 +820,8 @@ impl<S: Scalar> DataTyped for View<S> {
/// A marker type of typed mutable data of a tensor.
pub struct ViewMut<S: Scalar>(PhantomData<S>);
impl<S: Scalar> Data for ViewMut<S> {
type Immutable = View<S>;
type Mutable = ViewMut<S>;
type TypeErased = ViewMutAny;
type Typed<S2: Scalar> = ViewMut<S2>;
private_impl! {}
Expand All @@ -797,6 +834,8 @@ impl<S: Scalar> DataTyped for ViewMut<S> {
/// A marker type of type-erased immutable viewed data of a tensor.
pub struct ViewAny;
impl Data for ViewAny {
type Immutable = ViewAny;
type Mutable = ViewMutAny;
type TypeErased = ViewAny;
type Typed<S: Scalar> = View<S>;
private_impl! {}
Expand All @@ -805,6 +844,8 @@ impl Data for ViewAny {
/// A marker type of type-erased mutable viewed data of a tensor.
pub struct ViewMutAny;
impl Data for ViewMutAny {
type Immutable = ViewAny;
type Mutable = ViewMutAny;
type TypeErased = ViewMutAny;
type Typed<S: Scalar> = ViewMut<S>;
private_impl! {}
Expand Down

0 comments on commit 4d3e3d3

Please sign in to comment.