Skip to content

Commit

Permalink
piou piou
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed May 10, 2023
1 parent 88c28ca commit 0ec9ed5
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,14 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
basis_primal = construct_value(basis_dual)
scfres = self_consistent_field(basis_primal; kwargs...)
= eltype(eltype(scfres.ψ[1]))
ψ_t = [reinterpret(reshape, Tψ, ψk) for ψk in scfres.ψ]
ψ0_t = [reinterpret(reshape, Tψ, ψk) for ψk in scfres.ψ]

## Compute external perturbation (contained in ham_dual) and from matvec with bands
Tψd = nothing
Tψd2 = nothing
Hψ_dual = let
occupation_dual = [T.(occk) for occk in scfres.occupation]
ψ_dual_t = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in ψ_t]
ψ_dual_t = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in ψ0_t]
Tψd = eltype(eltype(ψ_dual_t[1]))
ψ_dual = [reinterpret(reshape, SVector{basis_dual.model.n_components, Tψd},
ψk_dual_t) for ψk_dual_t in ψ_dual_t]
Expand All @@ -216,7 +216,8 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
ham_dual = energy_hamiltonian(basis_dual, ψ_dual, occupation_dual;
ρ=ρ_dual, eigenvalues=eigenvalues_dual,
εF=εF_dual).ham
ham_dual * ψ_dual
res = ham_dual * ψ_dual_t
[reinterpret(reshape, SVector{basis_dual.model.n_components, Tψd}, resk) for resk in res]
end
Hψ_dual_t = [reinterpret(reshape, Tψd, δHextψk) for δHextψk in Hψ_dual]

Expand All @@ -234,12 +235,14 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
DT = ForwardDiff.Dual{ForwardDiff.tagtype(T)}
δψ = [[δψik for δψik in δψi] for δψi in getfield.(δresults, :δψ)]
δψ_t = [[reinterpret(reshape, Tψd2, δψik) for δψik in δψk] for δψk in δψ]
ψ = map(ψ_t, δψ_t...) do ψk, δψk...
ψ_t = map(ψ0_t, δψ_t...) do ψk, δψk...
map(ψk, δψk...) do ψnk, δψnk...
Complex(DT(real(ψnk), real.(δψnk)),
DT(imag(ψnk), imag.(δψnk)))
end
end
ψ = [reinterpret(reshape, SVector{basis_dual.model.n_components, eltype(eltype(ψ_t[1]))}, ψk)
for ψk in ψ_t]
ρ = map((ρi, δρi...) -> DT(ρi, δρi), scfres.ρ, getfield.(δresults, :δρ)...)
eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk...
map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)
Expand Down

0 comments on commit 0ec9ed5

Please sign in to comment.