-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unify static and dynamic-rank tensor implementations #43
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add a new tensor base type in `rten_tensor::unified_tensor` which is generic over the layout and can represent either a static or dynamic rank tensor. Unifying the implementation will help to avoid unintended API differences between the two. The API is largely compatible with the existing tensor types and the same set of aliases are provided (`NdTensor*` for static rank, `Tensor*` for dynamic rank).
- Change the library entry point to export the unified tensor types / traits and aliases, under the same names as the legacy tensor types - Add missing trait import in iterator tests - Remove old `AxisIter*`, `InnerIter*` types from rten-tensor iterators
Adapt rten to the API changes arising from unifying `TensorBase` and `NdTensorBase`.
Now that the legacy implementation has been removed, we can drop the `unified_` part of the name, and also combine iterators back into a single module.
robertknight
changed the title
Unify static and dynamic rank tensor implementations
Unify static and dynamic-rank tensor implementations
Jan 21, 2024
Static-rank views are preferable when the dimensionality is known, as they are more efficient and the compiler can better catch errors.
This was added at an intermediate point during the unified tensor implementation, before `From` had been added.
This aligns with `slice_dyn`.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This mega-PR unifies the implementations of static and dynamic rank tensors in the rten-tensor crate. In other words, it merges
TensorBase
andNdTensorBase
into oneTensorBase
struct that is generic over the layout. The main motivations for this are:The API of the unified tensor types are mostly the same as before, with some small changes arising from methods now having more generic argument types or infering their result types differently, as well as resolving some cases where there were unintentional API differences. The same set of type aliases are provided for owned tensors and views as before, so most code consuming this crate only needs minor changes.
In summary:
NdTensor::from_data
now aligns withTensor::from_data
in taking(shape, data)
arguments. The previousfrom_data
API which accepted strides is now namedfrom_data_with_strides
dim
vsaxis
in method names has been straightened out so that it is at least consistentslice
method has been split into two, depending on whether a static or dynamic-rank result is required.slice::<M, _>(range)
returns a static-rank view withM
dims.slice_dyn(range)
returns a dynamic rank view.The unified tensor implementation is currently in one very large module in
rten-tensor/src/tensor.rs
. Splitting this up into smaller pieces will be done once this is merged.