Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upstream ForwardDiff rules to NNlib.jl #93

Open
avik-pal opened this issue Jul 16, 2024 · 0 comments
Open

upstream ForwardDiff rules to NNlib.jl #93

avik-pal opened this issue Jul 16, 2024 · 0 comments
Labels
good first issue Good for newcomers

Comments

@avik-pal
Copy link
Member

Code:

# ForwardDiff support for NNlib's convolution primitives.
#
# `conv`, `depthwiseconv`, `∇conv_data` and `∇conv_filter` are linear in each array
# argument when the other one is held fixed. The pushforward along partial `i` is
# therefore the primal op applied to the `i`-th partials, plus the product-rule term
# when both arguments carry duals of the same tag. The primal work is delegated to the
# LuxLib kernel `__$(op)`, which is passed to the helpers below as `f`.

# Build `Dual{Tag}` outputs from the primal result `y` and the per-partial results `dys`.
# The two-argument `Dual{Tag}(value, partials)` constructor promotes to the element type
# `f` actually produced (e.g. Float64 for Float32 × Float64 inputs) instead of forcing
# the input's value type back onto the result.
function __fd_assemble_duals(::Type{Tag}, y::AbstractArray, dys::Tuple) where {Tag}
    partials = ForwardDiff.Partials.(tuple.(dys...))
    return ForwardDiff.Dual{Tag}.(y, partials)
end

# Pushforward of `f(x1, x2, cdims)` w.r.t. `Tag` when only `x1` carries `Dual{Tag}`s
# (with `P` partials); `x2` is treated as a constant.
function __fd_pushforward_lhs(f::F, ::Type{Tag}, ::Val{P}, x1, x2, cdims;
        kwargs...) where {F, Tag, P}
    y = f(ForwardDiff.value.(Tag, x1), x2, cdims; kwargs...)
    dys = ntuple(Val(P)) do i
        return f(ForwardDiff.partials.(Tag, x1, i), x2, cdims; kwargs...)
    end
    return __fd_assemble_duals(Tag, y, dys)
end

# Pushforward of `f(x1, x2, cdims)` w.r.t. `Tag` when only `x2` carries `Dual{Tag}`s
# (with `P` partials); `x1` is treated as a constant.
function __fd_pushforward_rhs(f::F, ::Type{Tag}, ::Val{P}, x1, x2, cdims;
        kwargs...) where {F, Tag, P}
    y = f(x1, ForwardDiff.value.(Tag, x2), cdims; kwargs...)
    dys = ntuple(Val(P)) do i
        return f(x1, ForwardDiff.partials.(Tag, x2, i), cdims; kwargs...)
    end
    return __fd_assemble_duals(Tag, y, dys)
end

# Pushforward of `f(x1, x2, cdims)` when both arguments carry `Dual{Tag}`s with `P`
# partials: by the product rule, d f(x1, x2) = f(dx1, x2) + f(x1, dx2).
function __fd_pushforward_both(f::F, ::Type{Tag}, ::Val{P}, x1, x2, cdims;
        kwargs...) where {F, Tag, P}
    v1 = ForwardDiff.value.(Tag, x1)
    v2 = ForwardDiff.value.(Tag, x2)
    y = f(v1, v2, cdims; kwargs...)
    dys = ntuple(Val(P)) do i
        dy = f(ForwardDiff.partials.(Tag, x1, i), v2, cdims; kwargs...)
        # `dy` is a fresh output of `f`, so accumulating into it in place is safe.
        dy .+= f(v1, ForwardDiff.partials.(Tag, x2, i), cdims; kwargs...)
        return dy
    end
    return __fd_assemble_duals(Tag, y, dys)
end

for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter]
    luxlibop = Symbol("__$(op)")

    # Only the first argument is differentiated.
    @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
            x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims;
            kwargs...) where {N, Tag, V, P}
        return __fd_pushforward_lhs(
            $(luxlibop), Tag, Val(P), x1, x2, cdims; kwargs...)
    end

    # Only the second argument is differentiated.
    @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N},
            x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
            cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
        return __fd_pushforward_rhs(
            $(luxlibop), Tag, Val(P), x1, x2, cdims; kwargs...)
    end

    # Both arguments carry duals. This method is more specific than the two above, so it
    # also resolves what would otherwise be a method ambiguity when the tags differ
    # (nested differentiation). The tag checks are on type parameters and fold at
    # compile time.
    @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{T₁, V₁, P₁}, N},
            x2::AbstractArray{<:ForwardDiff.Dual{T₂, V₂, P₂}, N},
            cdims::NNlib.ConvDims; kwargs...) where {N, T₁, V₁, P₁, T₂, V₂, P₂}
        if T₁ === T₂
            # Same perturbation: apply the product rule over both arguments.
            P₁ == P₂ || throw(ArgumentError(string(
                "mismatched number of partials for dual tag ", T₁, ": ", P₁, " vs ", P₂)))
            return __fd_pushforward_both(
                $(luxlibop), T₁, Val(P₁), x1, x2, cdims; kwargs...)
        elseif ForwardDiff.:≺(T₂, T₁)
            # `T₁` is the outer perturbation: `x2` is constant with respect to it.
            return __fd_pushforward_lhs(
                $(luxlibop), T₁, Val(P₁), x1, x2, cdims; kwargs...)
        else
            # `T₂` is the outer perturbation: `x1` is constant with respect to it.
            return __fd_pushforward_rhs(
                $(luxlibop), T₂, Val(P₂), x1, x2, cdims; kwargs...)
        end
    end
end

Tests: https://github.com/LuxDL/LuxLib.jl/blob/main/test/others/forwarddiff_tests.jl

The main reason I haven't done this yet is that the tests are structured in a way that forces you to copy the code for each backend.

@avik-pal avik-pal added the good first issue Good for newcomers label Jul 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
good first issue Good for newcomers
Projects
None yet
Development

No branches or pull requests

1 participant