Skip to content

Commit

Permalink
Merge pull request #34 from bmad-sim/strict1
Browse files Browse the repository at this point in the history
strict TaylorMap sanity and simplification of code
  • Loading branch information
mattsignorelli authored Jun 24, 2024
2 parents c9f732f + 3cac710 commit 70932c8
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 35 deletions.
60 changes: 32 additions & 28 deletions src/map/ctors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function $t{S,T,U,V,W}(u::UndefInitializer; use=GTPSA.desc_current, idpt::W=noth
end

"""
$($t)(; x::Union{Vector{<:Union{TPS,ComplexTPS}},Nothing}=nothing, x0::Union{Vector,Nothing}=nothing, Q::Union{Quaternion{<:Union{TPS,ComplexTPS}},Nothing}=nothing, E::Union{Matrix,Nothing}=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing, idpt::Union{Nothing,Bool}=nothing, use=nothing)
$($t)(;use=GTPSA.desc_current, x::Vector=vars(getdesc(use)), x0::Vector=zeros(numtype(eltype(x)), numvars(use)), Q::Union{Quaternion,Nothing}=nothing, E::Union{Matrix,Nothing}=nothing, idpt::Union{Bool,Nothing}=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing)
Constructs a $($t) with the passed vector of `TPS`/`ComplexTPS` as the orbital ray, and optionally the entrance
coordinates `x0`, `Quaternion` for spin `Q`, and FD matrix `E` as keyword arguments. The helper keyword
Expand All @@ -108,43 +108,47 @@ specified is type-unstable. This constructor also checks for consistency in the
`Descriptor`. The `use` kwarg may also be used to change the `Descriptor` of the TPSs, provided the number of variables
+ parameters agree (orders may be different).
"""
function $t(;use=GTPSA.desc_current, x::Vector{T}=vars(getdesc(use)), x0::Vector{S}=zeros(numtype(eltype(x)), numvars(use)), Q::U=nothing, E::V=nothing, idpt::W=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing) where {S,T<:Union{TPS,ComplexTPS},U<:Union{Quaternion{<:Union{TPS,ComplexTPS}},Nothing},V<:Union{Matrix,Nothing},W<:Union{Nothing,Bool}}
function $t(;use=GTPSA.desc_current, x::Vector=vars(getdesc(use)), x0::Vector=zeros(numtype(eltype(x)), numvars(use)), Q::Union{Quaternion,Nothing}=nothing, E::Union{Matrix,Nothing}=nothing, idpt::Union{Bool,Nothing}=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing)
if !isnothing(Q)
if !isnothing(E)
T = promote_type(TPS,eltype(x0),eltype(x),eltype(Q),eltype(E))
else
T = promote_type(TPS,eltype(x0),eltype(x),eltype(Q))
end
else
T = promote_type(TPS,eltype(x0),eltype(x))
end

S = numtype(T)

# set up
if isnothing(spin)
if isnothing(Q)
outU = Nothing
U = Nothing
else
outU = typeof(Q)
U = Quaternion{T}
end
elseif spin
if isnothing(Q)
outU = Quaternion{T}
else
outU = typeof(Q)
end
U = Quaternion{T}
else
error("For no spin tracking, please omit the spin kwarg or set spin=nothing") # For type stability
#outU = Nothing # For type instability
#U = Nothing # For type instability
end

if isnothing(FD)
if isnothing(E)
outV = Nothing
V = Nothing
else
outV = typeof(E)
V = Matrix{S}
end
elseif FD
if isnothing(E)
outV = Matrix{numtype(T)}
else
outV = typeof(E)
end
V = Matrix{S}
else
error("For no fluctuation-dissipation, please omit the FD kwarg or set FD=nothing") # For type stability
#outV = Nothing # For type instability
#V = Nothing # For type instability
end

outm = $t{S,T,outU,outV,typeof(idpt)}(undef, use = use, idpt = idpt)
outm = $t{S,T,U,V,typeof(idpt)}(undef, use = use, idpt = idpt)

nv = numvars(use)
np = numparams(use)
Expand All @@ -164,11 +168,11 @@ function $t(;use=GTPSA.desc_current, x::Vector{T}=vars(getdesc(use)), x0::Vector
if !isnothing(outm.Q)
if !isnothing(Q)
for i=1:4
@inbounds outm.Q.q[i] = eltype(outU)(Q.q[i], use=getdesc(use))
@inbounds outm.Q.q[i] = T(Q.q[i], use=getdesc(use))
end
else
for i=1:4
@inbounds outm.Q.q[i] = (eltype(outU))(use=getdesc(use))
@inbounds outm.Q.q[i] = T(use=getdesc(use))
end
outm.Q.q[1][0] = 1
end
Expand Down Expand Up @@ -198,28 +202,28 @@ matrix, or `false` for no spin/FD. Note that setting `spin`/`FD` to any `Bool` v
specified is type-unstable. This constructor also checks for consistency in the length of the orbital ray and GTPSA
`Descriptor`.
"""
function $t(M::AbstractMatrix; use=GTPSA.desc_current, x::Vector{T}=vars(getdesc(use)), x0::Vector{S}=zeros(numtype(eltype(x)), numvars(use)), Q::U=nothing, E::V=nothing, idpt::W=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing) where {S,T<:Union{TPS,ComplexTPS},U<:Union{Quaternion{<:Union{TPS,ComplexTPS}},Nothing},V<:Union{Matrix,Nothing},W<:Union{Nothing,Bool}}
function $t(M::AbstractMatrix; use=GTPSA.desc_current, x0::Vector=zeros(numtype(eltype(x)), numvars(use)), Q::Union{Quaternion,Nothing}=nothing, E::Union{Matrix,Nothing}=nothing, idpt::Union{Bool,Nothing}=nothing, spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing)
Base.require_one_based_indexing(M)
nv = numvars(use)
nn = numnn(use)

nv >= size(M,1) || error("Number of rows in matrix > number of variables in GTPSA!")

if eltype(M) <: Complex
outT = ComplexTPS
T = ComplexTPS
else
outT = TPS
T = TPS
end

x = Vector{outT}(undef, nv)
x = Vector{T}(undef, nv)
for i=1:size(M,1)
@inbounds x[i] = (outT)(use=getdesc(use))
@inbounds x[i] = T(use=getdesc(use))
for j=1:size(M,2)
@inbounds x[i][j] = M[i,j]
end
end

return DAMap(use=use, x=x, x0=x0,Q=Q,E=E,idpt=idpt,spin=spin,FD=FD)
return DAMap(use=use,x=x,x0=x0,Q=Q,E=E,idpt=idpt,spin=spin,FD=FD)
end


Expand Down
20 changes: 19 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct DAMap{S,T<:Union{TPS,ComplexTPS},U<:Union{Quaternion{T},Nothing},V<:Union
Q::U # Quaternion for spin
E::V # Envelope for stochasticity
idpt::W # Specifies index of constant (energy-like) variable

function DAMap{S,T,U,V,W}(x0, x, Q, E, idpt) where {S,T,U,V,W}
m = new{S,T,U,V,W}(x0, x, Q, E, idpt)
checkmapsanity(m)
return m
end
end

"""
Expand All @@ -50,6 +56,12 @@ struct TPSAMap{S,T<:Union{TPS,ComplexTPS},U<:Union{Quaternion{T},Nothing},V<:Uni
Q::U # Quaternion for spin
E::V # Envelope for stochasticity
idpt::W # Specifies index of constant (energy-like) variable

function TPSAMap{S,T,U,V,W}(x0, x, Q, E, idpt) where {S,T,U,V,W}
m = new{S,T,U,V,W}(x0, x, Q, E, idpt)
checkmapsanity(m)
return m
end
end

"""
Expand Down Expand Up @@ -93,13 +105,19 @@ for t = (:DAMap, :TPSAMap)
function promote_rule(::Type{$t{S,T,U,V,W}}, ::Type{G}) where {S,T,U,V,W,G<:Union{Number,Complex}}
outS = promote_type(S,numtype(T),G)
outT = promote_type(T,G)
println("hi")
U != Nothing ? outU = Quaternion{promote_type(eltype(U), G)} : outU = Nothing
V != Nothing ? outV = promote_type(Matrix{G},V) : outV = Nothing
return $t{outS,outT,outU,outV,W}
end

# Currently promote_type in promotion.jl gives
# promote_type(::Type{T}, ::Type{T}) where {T} = T
# and does not even call promote_rule, therefore this is never reached
# Therefore I will required the reference orbit to have the same numtype as the
# TPS at construction.
function promote_rule(::Type{$t{S1,T1,U1,V1,W}}, ::Type{$t{S2,T2,U2,V2,W}}) where {S1,S2,T1,T2,U1,U2,V1,V2,W}
outS = promote_type(S1, S2, numtype(T1), numtype(T2))
outS = promote_type(numtype(T1), numtype(T2))
outT = promote_type(T1, T2)
U1 != Nothing ? outU = Quaternion{promote_type(eltype(U2),eltype(U2))} : outU = Nothing
V1 != Nothing ? outV = promote_type(V1,V2) : outV = Nothing
Expand Down
6 changes: 6 additions & 0 deletions src/utils/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ pords(m::Union{Probe{<:Real,<:Union{TPS,ComplexTPS},<:Any,<:Any},<:TaylorMap,Vec
return true
end

@inline function checkmapsanity(m::TaylorMap{S,T,U,V,W}) where {S,T,U,V,W}
S == numtype(T) || error("Reference orbit type $S must be $(numtype(T)) (equal to scalar of orbital)")
# already checked in type: U == Nothing || eltype(U) == T || error("Quaternion type $(eltype(U)) must be $T (equal to orbital)")
V == Nothing || eltype(V) == numtype(T) || error("Stochastic matrix type $(eltype(V)) must be $(numtype(T)) (equal to scalar of orbital)")
end

# --- random symplectic map ---
function rand(t::Union{Type{DAMap},Type{TPSAMap}}; spin::Union{Bool,Nothing}=nothing, FD::Union{Bool,Nothing}=nothing, use::Union{Descriptor,TPS,ComplexTPS}=GTPSA.desc_current, ndpt::Union{Nothing,Integer}=nothing)
if isnothing(spin)
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@ using Test
@testset "NonlinearNormalForm.jl" begin
d = Descriptor(1,2)
x1 = vars()[1]
m1 = DAMap(x=[1+2*x1+2*x1^2], x0=[4.])
m2 = DAMap(x=[1+2*x1+2*x1^2], x0=[3.])
m1 = DAMap(x=[1+2*x1+2*x1^2], x0=[4])
m2 = DAMap(x=[1+2*x1+2*x1^2], x0=[3])

mt1 = TPSAMap(m1)
mt2 = TPSAMap(m2)

tol = 1e-10

@test norm(m2m1 - DAMap(x=[1+4*x1+12*x1^2], x0=[4.])) < tol
@test norm(mt2mt1 - TPSAMap(x=[5-12*x1-4*x1^2], x0=[4.])) < tol
@test norm(m2m1 - DAMap(x=[1+4*x1+12*x1^2], x0=[4])) < tol
@test norm(mt2mt1 - TPSAMap(x=[5-12*x1-4*x1^2], x0=[4])) < tol
@test norm(m2^3 - m2m2m2) < tol
@test norm(mt2^3 - mt2mt2mt2) < tol
@test norm(m2^3 - DAMap(x=[1+8*x1+56*x1^2], x0=[3.])) < tol
@test norm(mt2^3 - TPSAMap(x=[13+8*x1-8*x1^2], x0=[3.])) < tol
@test norm(m2^3 - DAMap(x=[1+8*x1+56*x1^2], x0=[3])) < tol
@test norm(mt2^3 - TPSAMap(x=[13+8*x1-8*x1^2], x0=[3])) < tol

# with temporary inverter"
#@test norm(m1^-3 - DAMap(x=[4+0.125*x1-0.109375*x1^2],x0=[1])) < tol
Expand Down

0 comments on commit 70932c8

Please sign in to comment.