|
88 | 88 | end
|
89 | 89 | end
|
90 | 90 | end
|
| 91 | + |
@testsetup module ForwardADNLLSTesting
# Shared fixtures for the nonlinear least-squares ForwardDiff integration tests:
# a parametric model, target data, residual functions (out-of-place and in-place)
# and their Jacobian / vector-Jacobian-product companions.
using Reexport
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
                Zygote

# Model evaluated elementwise over `x`: θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]).
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

# Parameters used to generate the target data, the sample points, and the targets.
const θ_true = [1.0, 0.1, 2.0, 0.5]
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
const y_target = true_function(x, θ_true)

# Out-of-place residual: model prediction at sample points `p` minus `y_target`.
function loss_function(θ, p)
    ŷ = true_function(p, θ)
    return ŷ .- y_target
end

# Jacobian of the residual with respect to θ (p held fixed), via ForwardDiff.
function loss_function_jac(θ, p)
    return ForwardDiff.jacobian(ϑ -> loss_function(ϑ, p), θ)
end

# Vector-Jacobian product vᵀJ, returned with the same shape as θ.
loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

# In-place residual: writes prediction minus `y_target` into `resid`.
function loss_function!(resid, θ, p)
    ŷ = true_function(p, θ)
    @. resid = ŷ - y_target
    return
end

# In-place Jacobian: delegates to `loss_function_jac` so both forms stay in sync.
function loss_function_jac!(J, θ, p)
    J .= loss_function_jac(θ, p)
    return
end

# In-place vector-Jacobian product: delegates to `loss_function_vjp` and writes the
# flattened result into `vJ`.
function loss_function_vjp!(vJ, v, θ, p)
    vec(vJ) .= vec(loss_function_vjp(v, θ, p))
    return
end

# Initial guess for the solvers: the true parameters shifted by 0.1. Declared `const`
# like the other fixtures so that the test closures reading it do not go through an
# untyped global.
const θ_init = θ_true .+ 0.1

export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
       loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
end
| 135 | + |
@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
    # Each solver is constructed both with no arguments and with `AutoFiniteDiff()`.
    solvers = (
        SimpleNewtonRaphson(), SimpleGaussNewton(),
        SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))

    @testset "$(nameof(typeof(alg)))" for alg in solvers
        # Solve `prob` with the current algorithm and reduce the solution to the
        # scalar sum of squares of `sol.u`.
        solution_sumsq(prob) = sum(abs2, solve(prob, alg).u)

        # ---- Out-of-place problems: plain, with user Jacobian, with user VJP ----
        function oop_plain(p)
            prob = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
            return solution_sumsq(prob)
        end

        function oop_with_jac(p)
            nlf = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
            return solution_sumsq(NonlinearLeastSquaresProblem{false}(nlf, θ_init, p))
        end

        function oop_with_vjp(p)
            nlf = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
            return solution_sumsq(NonlinearLeastSquaresProblem{false}(nlf, θ_init, p))
        end

        # Finite differences of the plain objective serve as the reference gradient.
        reference_oop = FiniteDiff.finite_difference_gradient(oop_plain, x)

        grad_oop_plain = ForwardDiff.gradient(oop_plain, x)
        grad_oop_jac = ForwardDiff.gradient(oop_with_jac, x)
        grad_oop_vjp = ForwardDiff.gradient(oop_with_vjp, x)

        @test reference_oop ≈ grad_oop_plain atol=1e-5
        @test reference_oop ≈ grad_oop_jac atol=1e-5
        @test reference_oop ≈ grad_oop_vjp atol=1e-5
        @test grad_oop_plain ≈ grad_oop_jac ≈ grad_oop_vjp

        # ---- In-place problems: plain, with user Jacobian, with user VJP ----
        function iip_plain(p)
            nlf = NonlinearFunction{true}(
                loss_function!; resid_prototype = zeros(length(y_target)))
            return solution_sumsq(NonlinearLeastSquaresProblem(nlf, θ_init, p))
        end

        function iip_with_jac(p)
            nlf = NonlinearFunction{true}(
                loss_function!; resid_prototype = zeros(length(y_target)),
                jac = loss_function_jac!)
            return solution_sumsq(NonlinearLeastSquaresProblem(nlf, θ_init, p))
        end

        function iip_with_vjp(p)
            nlf = NonlinearFunction{true}(
                loss_function!; resid_prototype = zeros(length(y_target)),
                vjp = loss_function_vjp!)
            return solution_sumsq(NonlinearLeastSquaresProblem(nlf, θ_init, p))
        end

        reference_iip = FiniteDiff.finite_difference_gradient(iip_plain, x)

        grad_iip_plain = ForwardDiff.gradient(iip_plain, x)
        grad_iip_jac = ForwardDiff.gradient(iip_with_jac, x)
        grad_iip_vjp = ForwardDiff.gradient(iip_with_vjp, x)

        @test reference_iip ≈ grad_iip_plain atol=1e-5
        @test reference_iip ≈ grad_iip_jac atol=1e-5
        @test reference_iip ≈ grad_iip_vjp atol=1e-5
        @test grad_iip_plain ≈ grad_iip_jac ≈ grad_iip_vjp
    end
end
0 commit comments