Skip to content

Commit 693949b

Browse files
committed
refactor: cleanup all wrappers
1 parent 4f74479 commit 693949b

21 files changed

+1214
-1072
lines changed

Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
1010
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1111
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1212
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
13+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1314
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1415
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1516
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -62,10 +63,12 @@ Aqua = "0.8"
6263
ArrayInterface = "7.16"
6364
BandedMatrices = "1.5"
6465
BenchmarkTools = "1.4"
66+
BracketingNonlinearSolve = "1"
6567
CUDA = "5.5"
6668
CommonSolve = "0.2.4"
6769
ConcreteStructs = "0.2.3"
6870
DiffEqBase = "6.155.3"
71+
DifferentiationInterface = "0.6.18"
6972
Enzyme = "0.13.11"
7073
ExplicitImports = "1.5"
7174
FastClosures = "0.3.2"
@@ -87,6 +90,9 @@ NLsolve = "4.5"
8790
NaNMath = "1"
8891
NonlinearProblemLibrary = "0.1.2"
8992
NonlinearSolveBase = "1"
93+
NonlinearSolveFirstOrder = "1"
94+
NonlinearSolveQuasiNewton = "1"
95+
NonlinearSolveSpectralMethods = "1"
9096
OrdinaryDiffEqTsit5 = "1.1.0"
9197
PETSc = "0.2"
9298
Pkg = "1.10"

common/common_nlls_testing.jl

+1
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ prob_iip_vjp = NonlinearLeastSquaresProblem(
4747
)
4848

4949
export prob_oop, prob_iip, prob_oop_vjp, prob_iip_vjp
50+
export true_function, θ_true, x, y_target, loss_function, θ_init
+49-37
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,84 @@
11
module NonlinearSolveFastLevenbergMarquardtExt
22

3-
using ArrayInterface: ArrayInterface
43
using FastClosures: @closure
4+
5+
using ArrayInterface: ArrayInterface
56
using FastLevenbergMarquardt: FastLevenbergMarquardt
6-
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
7-
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL
8-
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
97
using StaticArraysCore: SArray
108

11-
const FastLM = FastLevenbergMarquardt
9+
using NonlinearSolveBase: NonlinearSolveBase
10+
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL
11+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode
1212

13-
@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
14-
if linsolve === :cholesky
15-
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
16-
elseif linsolve === :qr
17-
return FastLM.QRSolver(eltype(x), length(x))
18-
else
19-
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
20-
end
21-
end
22-
@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve
13+
const FastLM = FastLevenbergMarquardt
2314

24-
function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem},
25-
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
26-
reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...)
27-
NonlinearSolve.__test_termination_condition(
28-
termination_condition, :FastLevenbergMarquardt)
15+
function SciMLBase.__solve(
16+
prob::AbstractNonlinearProblem, alg::FastLevenbergMarquardtJL, args...;
17+
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000,
18+
termination_condition = nothing, kwargs...
19+
)
20+
NonlinearSolveBase.assert_extension_supported_termination_condition(
21+
termination_condition, alg
22+
)
2923

30-
fn, u, resid = NonlinearSolve.__construct_extension_f(
31-
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray))
24+
f_wrapped, u, resid = NonlinearSolveBase.construct_extension_function_wrapper(
25+
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray)
26+
)
3227
f = if prob.u0 isa SArray
33-
@closure (u, p) -> fn(u)
28+
@closure (u, p) -> f_wrapped(u)
3429
else
35-
@closure (du, u, p) -> fn(du, u)
30+
@closure (du, u, p) -> f_wrapped(du, u)
3631
end
37-
abstol = get_tolerance(abstol, eltype(u))
38-
reltol = get_tolerance(reltol, eltype(u))
3932

40-
_jac_fn = NonlinearSolve.__construct_extension_jac(
41-
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray))
33+
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u))
34+
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u))
35+
36+
jac_fn_wrapped = NonlinearSolveBase.construct_extension_jac(
37+
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray)
38+
)
4239
jac_fn = if prob.u0 isa SArray
43-
@closure (u, p) -> _jac_fn(u)
40+
@closure (u, p) -> jac_fn_wrapped(u)
4441
else
45-
@closure (J, u, p) -> _jac_fn(J, u)
42+
@closure (J, u, p) -> jac_fn_wrapped(J, u)
4643
end
4744

48-
solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
45+
solver_kwargs = (;
46+
xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
4947
alg.factor, alg.factoraccept, alg.factorreject, alg.minscale,
50-
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor)
48+
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor
49+
)
5150

5251
if prob.u0 isa SArray
5352
res, fx, info, iter, nfev, njev = FastLM.lmsolve(
54-
f, jac_fn, prob.u0; solver_kwargs...)
53+
f, jac_fn, prob.u0; solver_kwargs...
54+
)
5555
LM, solver = nothing, nothing
5656
else
5757
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
5858
zero(prob.f.jac_prototype)
59-
solver = _fast_lm_solver(alg, u)
59+
60+
solver = if alg.linsolve === :cholesky
61+
FastLM.CholeskySolver(ArrayInterface.undefmatrix(u))
62+
elseif alg.linsolve === :qr
63+
FastLM.QRSolver(eltype(u), length(u))
64+
else
65+
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: \
66+
$(Meta.quot(alg.linsolve))"))
67+
end
68+
6069
LM = FastLM.LMWorkspace(u, resid, J)
6170

6271
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(
63-
f, jac_fn, LM; solver, solver_kwargs...)
72+
f, jac_fn, LM; solver, solver_kwargs...
73+
)
6474
end
6575

6676
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
6777
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
68-
return SciMLBase.build_solution(prob, alg, res, fx; retcode,
69-
original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
78+
return SciMLBase.build_solution(
79+
prob, alg, res, fx; retcode,
80+
original = (res, fx, info, iter, nfev, njev, LM, solver), stats
81+
)
7082
end
7183

7284
end
+27-17
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,51 @@
11
module NonlinearSolveFixedPointAccelerationExt
22

3-
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
3+
using FixedPointAcceleration: FixedPointAcceleration, fixed_point
4+
5+
using NonlinearSolveBase: NonlinearSolveBase
46
using NonlinearSolve: NonlinearSolve, FixedPointAccelerationJL
57
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
6-
using FixedPointAcceleration: FixedPointAcceleration, fixed_point
78

8-
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
9+
function SciMLBase.__solve(
10+
prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
911
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
10-
show_trace::Val{PrintReports} = Val(false),
11-
termination_condition = nothing, kwargs...) where {PrintReports}
12-
NonlinearSolve.__test_termination_condition(
13-
termination_condition, :FixedPointAccelerationJL)
12+
show_trace::Val = Val(false), termination_condition = nothing, kwargs...
13+
)
14+
NonlinearSolveBase.assert_extension_supported_termination_condition(
15+
termination_condition, alg
16+
)
17+
18+
f, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
19+
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true)
20+
)
1421

15-
f, u0, resid = NonlinearSolve.__construct_extension_f(
16-
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true))
17-
tol = get_tolerance(abstol, eltype(u0))
22+
tol = NonlinearSolveBase.get_tolerance(abstol, eltype(u0))
1823

19-
sol = fixed_point(f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
24+
sol = fixed_point(
25+
f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
2026
ConvergenceMetricThreshold = tol, ExtrapolationPeriod = alg.extrapolation_period,
21-
Dampening = alg.dampening, PrintReports, ReplaceInvalids = alg.replace_invalids,
22-
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
27+
Dampening = alg.dampening, PrintReports = show_trace isa Val{true},
28+
ReplaceInvalids = alg.replace_invalids,
29+
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true
30+
)
2331

2432
if sol.FixedPoint_ === missing
2533
u0 = prob.u0 isa Number ? u0[1] : u0
26-
resid = NonlinearSolve.evaluate_f(prob, u0)
34+
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0)
2735
res = u0
2836
converged = false
2937
else
3038
res = prob.u0 isa Number ? first(sol.FixedPoint_) :
3139
reshape(sol.FixedPoint_, size(prob.u0))
32-
resid = NonlinearSolve.evaluate_f(prob, res)
40+
resid = NonlinearSolveBase.Utils.evaluate_f(prob, res)
3341
converged = maximum(abs, resid) tol
3442
end
3543

36-
return SciMLBase.build_solution(prob, alg, res, resid; original = sol,
44+
return SciMLBase.build_solution(
45+
prob, alg, res, resid; original = sol,
3746
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
38-
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_))
47+
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_)
48+
)
3949
end
4050

4151
end
+45-54
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,70 @@
11
module NonlinearSolveLeastSquaresOptimExt
22

3-
using ConcreteStructs: @concrete
43
using LeastSquaresOptim: LeastSquaresOptim
5-
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance
4+
5+
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal
66
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL
7-
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
7+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode
88

99
const LSO = LeastSquaresOptim
1010

11-
@inline function _lso_solver(::LeastSquaresOptimJL{alg, ls}) where {alg, ls}
12-
linsolve = ls === :qr ? LSO.QR() :
13-
(ls === :cholesky ? LSO.Cholesky() : (ls === :lsmr ? LSO.LSMR() : nothing))
14-
if alg === :lm
15-
return LSO.LevenbergMarquardt(linsolve)
16-
elseif alg === :dogleg
17-
return LSO.Dogleg(linsolve)
18-
else
19-
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
20-
end
21-
end
22-
23-
@concrete struct LeastSquaresOptimJLCache
24-
prob
25-
alg
26-
allocated_prob
27-
kwargs
28-
end
29-
30-
function Base.show(io::IO, cache::LeastSquaresOptimJLCache)
31-
print(io, "LeastSquaresOptimJLCache()")
32-
end
33-
34-
function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...)
35-
error("Reinitialization not supported for LeastSquaresOptimJL.")
36-
end
37-
38-
function SciMLBase.__init(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem},
39-
alg::LeastSquaresOptimJL, args...; alias_u0 = false, abstol = nothing,
40-
show_trace::Val{ShT} = Val(false), trace_level = TraceMinimal(),
41-
reltol = nothing, store_trace::Val{StT} = Val(false), maxiters = 1000,
42-
termination_condition = nothing, kwargs...) where {ShT, StT}
43-
NonlinearSolve.__test_termination_condition(termination_condition, :LeastSquaresOptim)
11+
function SciMLBase.__solve(
12+
prob::AbstractNonlinearProblem, alg::LeastSquaresOptimJL, args...;
13+
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000,
14+
trace_level = TraceMinimal(), termination_condition = nothing,
15+
show_trace::Val = Val(false), store_trace::Val = Val(false), kwargs...
16+
)
17+
NonlinearSolveBase.assert_extension_supported_termination_condition(
18+
termination_condition, alg
19+
)
4420

45-
f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
46-
abstol = get_tolerance(abstol, eltype(u))
47-
reltol = get_tolerance(reltol, eltype(u))
21+
f!, u, resid = NonlinearSolveBase.construct_extension_function_wrapper(prob; alias_u0)
22+
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u))
23+
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u))
4824

4925
if prob.f.jac === nothing && alg.autodiff isa Symbol
50-
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid, alg.autodiff,
51-
J = prob.f.jac_prototype, output_length = length(resid))
26+
lsoprob = LSO.LeastSquaresProblem(;
27+
x = u, f!, y = resid, alg.autodiff, J = prob.f.jac_prototype,
28+
output_length = length(resid)
29+
)
5230
else
53-
g! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff)
31+
g! = NonlinearSolveBase.construct_extension_jac(prob, alg, u, resid; alg.autodiff)
5432
lsoprob = LSO.LeastSquaresProblem(;
5533
x = u, f!, y = resid, g!, J = prob.f.jac_prototype,
56-
output_length = length(resid))
34+
output_length = length(resid)
35+
)
5736
end
5837

59-
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
38+
linsolve = alg.ls === :qr ? LSO.QR() :
39+
(alg.ls === :cholesky ? LSO.Cholesky() :
40+
(alg.ls === :lsmr ? LSO.LSMR() : nothing))
6041

61-
return LeastSquaresOptimJLCache(prob,
62-
alg,
63-
allocated_prob,
64-
(; x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
65-
show_trace = ShT, store_trace = StT, show_every = trace_level.print_frequency))
66-
end
42+
lso_solver = if alg.alg === :lm
43+
LSO.LevenbergMarquardt(linsolve)
44+
elseif alg.alg === :dogleg
45+
LSO.Dogleg(linsolve)
46+
else
47+
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $(Meta.quot(alg.alg))"))
48+
end
49+
50+
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, lso_solver(alg))
51+
res = LSO.optimize!(
52+
allocated_prob;
53+
x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
54+
show_trace = show_trace isa Val{true}, store_trace = store_trace isa Val{true},
55+
show_every = trace_level.print_frequency
56+
)
6757

68-
function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)
69-
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
70-
maxiters = cache.kwargs[:iterations]
7158
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
7259
(res.iterations maxiters ? ReturnCode.MaxIters :
7360
ReturnCode.ConvergenceFailure)
7461
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
62+
63+
f!(resid, res.minimizer)
64+
7565
return SciMLBase.build_solution(
76-
cache.prob, cache.alg, res.minimizer, res.ssr / 2; retcode, original = res, stats)
66+
prob, alg, res.minimizer, resid; retcode, original = res, stats
67+
)
7768
end
7869

7970
end

0 commit comments

Comments
 (0)