From 0ec9ed58c7467705f7cdab8b44bf91eb908ae747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Polack?= Date: Wed, 10 May 2023 15:07:01 +0200 Subject: [PATCH] piou piou --- src/workarounds/forwarddiff_rules.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/workarounds/forwarddiff_rules.jl b/src/workarounds/forwarddiff_rules.jl index f58443af23..859c18c98d 100644 --- a/src/workarounds/forwarddiff_rules.jl +++ b/src/workarounds/forwarddiff_rules.jl @@ -199,14 +199,14 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T}; basis_primal = construct_value(basis_dual) scfres = self_consistent_field(basis_primal; kwargs...) Tψ = eltype(eltype(scfres.ψ[1])) - ψ_t = [reinterpret(reshape, Tψ, ψk) for ψk in scfres.ψ] + ψ0_t = [reinterpret(reshape, Tψ, ψk) for ψk in scfres.ψ] ## Compute external perturbation (contained in ham_dual) and from matvec with bands Tψd = nothing Tψd2 = nothing Hψ_dual = let occupation_dual = [T.(occk) for occk in scfres.occupation] - ψ_dual_t = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in ψ_t] + ψ_dual_t = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in ψ0_t] Tψd = eltype(eltype(ψ_dual_t[1])) ψ_dual = [reinterpret(reshape, SVector{basis_dual.model.n_components, Tψd}, ψk_dual_t) for ψk_dual_t in ψ_dual_t] @@ -216,7 +216,8 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T}; ham_dual = energy_hamiltonian(basis_dual, ψ_dual, occupation_dual; ρ=ρ_dual, eigenvalues=eigenvalues_dual, εF=εF_dual).ham - ham_dual * ψ_dual + res = ham_dual * ψ_dual_t + [reinterpret(reshape, SVector{basis_dual.model.n_components, Tψd}, resk) for resk in res] end Hψ_dual_t = [reinterpret(reshape, Tψd, δHextψk) for δHextψk in Hψ_dual] @@ -234,12 +235,14 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T}; DT = ForwardDiff.Dual{ForwardDiff.tagtype(T)} δψ = [[δψik for δψik in δψi] for δψi in getfield.(δresults, :δψ)] δψ_t = [[reinterpret(reshape, Tψd2, δψik) for δψik in δψk] for δψk in δψ] - ψ = map(ψ_t, δψ_t...) do ψk, δψk... + ψ_t = map(ψ0_t, δψ_t...) do ψk, δψk... map(ψk, δψk...) do ψnk, δψnk... Complex(DT(real(ψnk), real.(δψnk)), DT(imag(ψnk), imag.(δψnk))) end end + ψ = [reinterpret(reshape, SVector{basis_dual.model.n_components, eltype(eltype(ψ_t[1]))}, ψk) + for ψk in ψ_t] ρ = map((ρi, δρi...) -> DT(ρi, δρi), scfres.ρ, getfield.(δresults, :δρ)...) eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk... map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)