Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Preserve Indices When Copying Tracked Arrays #263

Merged
merged 4 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/similar_convert_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ end
function Base.convert(::Type{ComponentArray{T1,N,A1,Ax1}}, x::ComponentArray{T2,N,A2,Ax2}) where {T1,T2,N,A1,A2,Ax1,Ax2}
return T1.(x)
end
function Base.convert(::Type{ComponentArray{T,N,A1,Ax1}}, x::ComponentArray{T,N,A2,Ax2}) where {T,N,A1,A2,Ax1,Ax2}
return x
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't we want the axes to be from the type we're converting to? I think part of the contract of convert is that the output needs to be of the type you're converting to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It follows the same pattern as some of the other dispatches

function Base.convert(::Type{ComponentArray{T,N,A,Ax1}}, x::ComponentArray{T,N,A,Ax2}) where {T,N,A,Ax1,Ax2}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in those dispatches we'd want to just have Ax instead of Ax1 and Ax2 too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed! I do not know if there are downstream packages relying on that behaviour, so I will not touch those dispatches in this PR. For now, I have matched the axes for this dispatch

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I definitely need to change those, then. But that doesn't block this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have also had to add another trivial dispatch to avoid ambiguities.

end
Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x))

Base.convert(::Type{Cholesky{T1,Matrix{T1}}}, x::Cholesky{T2,<:ComponentArray}) where {T1,T2} = Cholesky(Matrix{T1}(x.factors), x.uplo, x.info)
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ComponentArrays
using BenchmarkTools
using ForwardDiff
using Tracker
using InvertedIndices
using LabelledArrays
using LinearAlgebra
Expand Down Expand Up @@ -400,6 +401,10 @@ end

@test convert(Array, ca) == getdata(ca)
@test convert(Matrix{Float32}, cmat) isa Matrix{Float32}

tr = Tracker.param(ca)
ca_ = convert(typeof(ca), tr)
@test ca_.a == ca.a
end

@testset "Broadcasting" begin
Expand Down
Loading