Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 66146a8

Browse files
committed
Use custom vjp
1 parent b7424c8 commit 66146a8

File tree

2 files changed

+151
-16
lines changed

2 files changed

+151
-16
lines changed

src/ad.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,39 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
7474

7575
uu = sol.u
7676

77-
if !SciMLBase.has_jac(prob.f)
77+
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78+
# nested autodiff as the last resort
79+
if SciMLBase.has_vjp(prob.f)
80+
if isinplace(prob)
81+
_F = @closure (du, u, p) -> begin
82+
resid = similar(du, length(sol.resid))
83+
prob.f(resid, u, p)
84+
prob.f.vjp(du, resid, u, p)
85+
du .*= 2
86+
return nothing
87+
end
88+
else
89+
_F = @closure (u, p) -> begin
90+
resid = prob.f(u, p)
91+
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
92+
end
93+
end
94+
elseif SciMLBase.has_jac(prob.f)
95+
if isinplace(prob)
96+
_F = @closure (du, u, p) -> begin
97+
J = similar(du, length(sol.resid), length(u))
98+
prob.f.jac(J, u, p)
99+
resid = similar(du, length(sol.resid))
100+
prob.f(resid, u, p)
101+
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
102+
return nothing
103+
end
104+
else
105+
_F = @closure (u, p) -> begin
106+
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
107+
end
108+
end
109+
else
78110
if isinplace(prob)
79111
_F = @closure (du, u, p) -> begin
80112
resid = similar(du, length(sol.resid))
@@ -103,21 +135,6 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
103135
end
104136
end
105137
end
106-
else
107-
if isinplace(prob)
108-
_F = @closure (du, u, p) -> begin
109-
J = similar(du, length(sol.resid), length(u))
110-
prob.jac(J, u, p)
111-
resid = similar(du, length(sol.resid))
112-
prob.f(resid, u, p)
113-
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
114-
return nothing
115-
end
116-
else
117-
_F = @closure (u, p) -> begin
118-
return reshape(2 .* vec(prob.f(u, p))' * prob.jac(u, p), size(u))
119-
end
120-
end
121138
end
122139

123140
f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)

test/core/forward_ad_tests.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,121 @@ end
8888
end
8989
end
9090
end
91+
92+
@testsetup module ForwardADNLLSTesting
93+
using Reexport
94+
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
95+
Zygote
96+
97+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
98+
99+
const θ_true = [1.0, 0.1, 2.0, 0.5]
100+
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
101+
const y_target = true_function(x, θ_true)
102+
103+
function loss_function(θ, p)
104+
= true_function(p, θ)
105+
return.- y_target
106+
end
107+
108+
function loss_function_jac(θ, p)
109+
return ForwardDiff.jacobian-> loss_function(θ, p), θ)
110+
end
111+
112+
loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
113+
114+
function loss_function!(resid, θ, p)
115+
= true_function(p, θ)
116+
@. resid =- y_target
117+
return
118+
end
119+
120+
function loss_function_jac!(J, θ, p)
121+
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
122+
return
123+
end
124+
125+
function loss_function_vjp!(vJ, v, θ, p)
126+
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
127+
return
128+
end
129+
130+
θ_init = θ_true .+ 0.1
131+
132+
export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
133+
loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
134+
end
135+
136+
@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
137+
@testset "$(nameof(typeof(alg)))" for alg in (
138+
SimpleNewtonRaphson(), SimpleGaussNewton(),
139+
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))
140+
function obj_1(p)
141+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
142+
sol = solve(prob_oop, alg)
143+
return sum(abs2, sol.u)
144+
end
145+
146+
function obj_2(p)
147+
ff = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
148+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
149+
sol = solve(prob_oop, alg)
150+
return sum(abs2, sol.u)
151+
end
152+
153+
function obj_3(p)
154+
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
155+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
156+
sol = solve(prob_oop, alg)
157+
return sum(abs2, sol.u)
158+
end
159+
160+
finitediff = FiniteDiff.finite_difference_gradient(obj_1, x)
161+
162+
fdiff1 = ForwardDiff.gradient(obj_1, x)
163+
fdiff2 = ForwardDiff.gradient(obj_2, x)
164+
fdiff3 = ForwardDiff.gradient(obj_3, x)
165+
166+
@test finitedifffdiff1 atol=1e-5
167+
@test finitedifffdiff2 atol=1e-5
168+
@test finitedifffdiff3 atol=1e-5
169+
@test fdiff1 fdiff2 fdiff3
170+
171+
function obj_4(p)
172+
prob_iip = NonlinearLeastSquaresProblem(
173+
NonlinearFunction{true}(
174+
loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p)
175+
sol = solve(prob_iip, alg)
176+
return sum(abs2, sol.u)
177+
end
178+
179+
function obj_5(p)
180+
ff = NonlinearFunction{true}(
181+
loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!)
182+
prob_iip = NonlinearLeastSquaresProblem(
183+
ff, θ_init, p)
184+
sol = solve(prob_iip, alg)
185+
return sum(abs2, sol.u)
186+
end
187+
188+
function obj_6(p)
189+
ff = NonlinearFunction{true}(
190+
loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!)
191+
prob_iip = NonlinearLeastSquaresProblem(
192+
ff, θ_init, p)
193+
sol = solve(prob_iip, alg)
194+
return sum(abs2, sol.u)
195+
end
196+
197+
finitediff = FiniteDiff.finite_difference_gradient(obj_4, x)
198+
199+
fdiff4 = ForwardDiff.gradient(obj_4, x)
200+
fdiff5 = ForwardDiff.gradient(obj_5, x)
201+
fdiff6 = ForwardDiff.gradient(obj_6, x)
202+
203+
@test finitedifffdiff4 atol=1e-5
204+
@test finitedifffdiff5 atol=1e-5
205+
@test finitedifffdiff6 atol=1e-5
206+
@test fdiff4 fdiff5 fdiff6
207+
end
208+
end

0 commit comments

Comments
 (0)