Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify static and dynamic-rank tensor implementations #43

Merged
merged 12 commits into from
Jan 22, 2024

Conversation

robertknight
Copy link
Owner

@robertknight robertknight commented Jan 21, 2024

This mega-PR unifies the implementations of static and dynamic rank tensors in the rten-tensor crate. In other words, it merges TensorBase and NdTensorBase into one TensorBase struct that is generic over the layout. The main motivations for this are:

  • To avoid API inconsistency between the two types, which had crept in over time
  • To reduce code duplication

The API of the unified tensor types are mostly the same as before, with some small changes arising from methods now having more generic argument types or infering their result types differently, as well as resolving some cases where there were unintentional API differences. The same set of type aliases are provided for owned tensors and views as before, so most code consuming this crate only needs minor changes.

In summary:

  • NdTensor::from_data now aligns with Tensor::from_data in taking (shape, data) arguments. The previous from_data API which accepted strides is now named from_data_with_strides
  • Methods for broadcasting and reshaping tensors now infer the result layout from the shape argument
  • The use of dim vs axis in method names has been straightened out so that it is at least consistent
  • The slice method has been split into two, depending on whether a static or dynamic-rank result is required. slice::<M, _>(range) returns a static-rank view with M dims. slice_dyn(range) returns a dynamic rank view.

The unified tensor implementation is currently in one very large module in rten-tensor/src/tensor.rs. Splitting this up into smaller pieces will be done once this is merged.

Add a new tensor base type in `rten_tensor::unified_tensor` which is generic
over the layout and can represent either a static or dynamic rank tensor.
Unifying the implementation will help to avoid unintended API differences
between the two.

The API is largely compatible with the existing tensor types and the same set of
aliases are provided (`NdTensor*` for static rank, `Tensor*` for dynamic rank).
 - Change the library entry point to export the unified tensor types /
   traits and aliases, under the same names as the legacy tensor types

 - Add missing trait import in iterator tests

 - Remove old `AxisIter*`, `InnerIter*` types from rten-tensor iterators
Adapt rten to the API changes arising from unifying `TensorBase` and
`NdTensorBase`.
Now that the legacy implementation has been removed, we can drop the `unified_`
part of the name, and also combine iterators back into a single module.
@robertknight robertknight changed the title Unify static and dynamic rank tensor implementations Unify static and dynamic-rank tensor implementations Jan 21, 2024
Static-rank views are preferable when the dimensionality is known, as they are
more efficient and the compiler can better catch errors.
This was added at an intermediate point during the unified tensor
implementation, before `From` had been added.
This aligns with `slice_dyn`.
@robertknight robertknight merged commit 84985ec into main Jan 22, 2024
1 check passed
@robertknight robertknight deleted the uni-tensor-migration branch January 22, 2024 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant