diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl index 635f9d12cf..0ddb810095 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -85,13 +85,19 @@ function ChainRulesCore.rrule( f′ = NoTangent() indp′ = NoTangent() - tunable = selected_tangents(buf′.tunable, tunable_idxs) - discrete = selected_tangents(buf′.discrete, disc_idxs) - constant = selected_tangents(buf′.constant, const_idxs) - nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs) + if buf′ isa AbstractArray + tunable = selected_tangents(buf′, tunable_idxs) + discrete = constant = nonnumeric = NoTangent() + vals′ = map(i -> buf′[i.idx], idxs) + else + tunable = selected_tangents(buf′.tunable, tunable_idxs) + discrete = selected_tangents(buf′.discrete, disc_idxs) + constant = selected_tangents(buf′.constant, const_idxs) + nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs) + vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs) + end oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric) idxs′ = NoTangent() - vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs) return f′, indp′, oldbuf′, idxs′, vals′ end end