|
1 | 1 | module NonlinearSolveFastLevenbergMarquardtExt
|
2 | 2 |
|
3 |
| -using ArrayInterface: ArrayInterface |
4 | 3 | using FastClosures: @closure
|
| 4 | + |
| 5 | +using ArrayInterface: ArrayInterface |
5 | 6 | using FastLevenbergMarquardt: FastLevenbergMarquardt
|
6 |
| -using NonlinearSolveBase: NonlinearSolveBase, get_tolerance |
7 |
| -using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL |
8 |
| -using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode |
9 | 7 | using StaticArraysCore: SArray
|
10 | 8 |
|
11 |
| -const FastLM = FastLevenbergMarquardt |
| 9 | +using NonlinearSolveBase: NonlinearSolveBase |
| 10 | +using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL |
| 11 | +using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode |
12 | 12 |
|
13 |
| -@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve} |
14 |
| - if linsolve === :cholesky |
15 |
| - return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x)) |
16 |
| - elseif linsolve === :qr |
17 |
| - return FastLM.QRSolver(eltype(x), length(x)) |
18 |
| - else |
19 |
| - throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve")) |
20 |
| - end |
21 |
| -end |
22 |
| -@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve |
| 13 | +const FastLM = FastLevenbergMarquardt |
23 | 14 |
|
24 |
| -function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem}, |
25 |
| - alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing, |
26 |
| - reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...) |
27 |
| - NonlinearSolve.__test_termination_condition( |
28 |
| - termination_condition, :FastLevenbergMarquardt) |
| 15 | +function SciMLBase.__solve( |
| 16 | + prob::AbstractNonlinearProblem, alg::FastLevenbergMarquardtJL, args...; |
| 17 | + alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000, |
| 18 | + termination_condition = nothing, kwargs... |
| 19 | +) |
| 20 | + NonlinearSolveBase.assert_extension_supported_termination_condition( |
| 21 | + termination_condition, alg |
| 22 | + ) |
29 | 23 |
|
30 |
| - fn, u, resid = NonlinearSolve.__construct_extension_f( |
31 |
| - prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray)) |
| 24 | + f_wrapped, u, resid = NonlinearSolveBase.construct_extension_function_wrapper( |
| 25 | + prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray) |
| 26 | + ) |
32 | 27 | f = if prob.u0 isa SArray
|
33 |
| - @closure (u, p) -> fn(u) |
| 28 | + @closure (u, p) -> f_wrapped(u) |
34 | 29 | else
|
35 |
| - @closure (du, u, p) -> fn(du, u) |
| 30 | + @closure (du, u, p) -> f_wrapped(du, u) |
36 | 31 | end
|
37 |
| - abstol = get_tolerance(abstol, eltype(u)) |
38 |
| - reltol = get_tolerance(reltol, eltype(u)) |
39 | 32 |
|
40 |
| - _jac_fn = NonlinearSolve.__construct_extension_jac( |
41 |
| - prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray)) |
| 33 | + abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u)) |
| 34 | + reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u)) |
| 35 | + |
| 36 | + jac_fn_wrapped = NonlinearSolveBase.construct_extension_jac( |
| 37 | + prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray) |
| 38 | + ) |
42 | 39 | jac_fn = if prob.u0 isa SArray
|
43 |
| - @closure (u, p) -> _jac_fn(u) |
| 40 | + @closure (u, p) -> jac_fn_wrapped(u) |
44 | 41 | else
|
45 |
| - @closure (J, u, p) -> _jac_fn(J, u) |
| 42 | + @closure (J, u, p) -> jac_fn_wrapped(J, u) |
46 | 43 | end
|
47 | 44 |
|
48 |
| - solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, |
| 45 | + solver_kwargs = (; |
| 46 | + xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, |
49 | 47 | alg.factor, alg.factoraccept, alg.factorreject, alg.minscale,
|
50 |
| - alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor) |
| 48 | + alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor |
| 49 | + ) |
51 | 50 |
|
52 | 51 | if prob.u0 isa SArray
|
53 | 52 | res, fx, info, iter, nfev, njev = FastLM.lmsolve(
|
54 |
| - f, jac_fn, prob.u0; solver_kwargs...) |
| 53 | + f, jac_fn, prob.u0; solver_kwargs... |
| 54 | + ) |
55 | 55 | LM, solver = nothing, nothing
|
56 | 56 | else
|
57 | 57 | J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
|
58 | 58 | zero(prob.f.jac_prototype)
|
59 |
| - solver = _fast_lm_solver(alg, u) |
| 59 | + |
| 60 | + solver = if alg.linsolve === :cholesky |
| 61 | + FastLM.CholeskySolver(ArrayInterface.undefmatrix(u)) |
| 62 | + elseif alg.linsolve === :qr |
| 63 | + FastLM.QRSolver(eltype(u), length(u)) |
| 64 | + else |
| 65 | + throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: \ |
| 66 | + $(Meta.quot(alg.linsolve))")) |
| 67 | + end |
| 68 | + |
60 | 69 | LM = FastLM.LMWorkspace(u, resid, J)
|
61 | 70 |
|
62 | 71 | res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(
|
63 |
| - f, jac_fn, LM; solver, solver_kwargs...) |
| 72 | + f, jac_fn, LM; solver, solver_kwargs... |
| 73 | + ) |
64 | 74 | end
|
65 | 75 |
|
66 | 76 | stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
|
67 | 77 | retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
|
68 |
| - return SciMLBase.build_solution(prob, alg, res, fx; retcode, |
69 |
| - original = (res, fx, info, iter, nfev, njev, LM, solver), stats) |
| 78 | + return SciMLBase.build_solution( |
| 79 | + prob, alg, res, fx; retcode, |
| 80 | + original = (res, fx, info, iter, nfev, njev, LM, solver), stats |
| 81 | + ) |
70 | 82 | end
|
71 | 83 |
|
72 | 84 | end
|
0 commit comments