1
1
module SimpleNonlinearSolve
2
2
3
- using Accessors: @reset
4
- using BracketingNonlinearSolve: BracketingNonlinearSolve
5
- using CommonSolve: CommonSolve, solve, init, solve!
6
3
using ConcreteStructs: @concrete
7
4
using FastClosures: @closure
8
- using LineSearch: LiFukushimaLineSearch
9
- using LinearAlgebra: LinearAlgebra, dot
10
- using MaybeInplace: @bb , setindex_trait, CannotSetindex, CanSetindex
11
5
using PrecompileTools: @compile_workload , @setup_workload
12
6
using Reexport: @reexport
13
- using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, NonlinearFunction, NonlinearProblem,
14
- NonlinearLeastSquaresProblem, IntervalNonlinearProblem, ReturnCode, remake
7
+ using Setfield: @set!
8
+
9
+ using BracketingNonlinearSolve: BracketingNonlinearSolve
10
+ using CommonSolve: CommonSolve, solve, init, solve!
11
+ using LineSearch: LiFukushimaLineSearch
12
+ using MaybeInplace: @bb
13
+ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
14
+ nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution,
15
+ AbstractNonlinearSolveAlgorithm
16
+ using SciMLBase: SciMLBase, NonlinearFunction, NonlinearProblem,
17
+ NonlinearLeastSquaresProblem, ReturnCode, remake
18
+
19
+ using LinearAlgebra: LinearAlgebra, dot
20
+
15
21
using StaticArraysCore: StaticArray, SArray, SVector, MArray
16
22
17
23
# AD Dependencies
18
24
using ADTypes: ADTypes, AutoForwardDiff
19
25
using DifferentiationInterface: DifferentiationInterface
20
26
using FiniteDiff: FiniteDiff
21
- using ForwardDiff: ForwardDiff
22
-
23
- using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
24
- nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
27
+ using ForwardDiff: ForwardDiff, Dual
25
28
26
29
const DI = DifferentiationInterface
27
30
28
- abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
31
+ const DualNonlinearProblem = NonlinearProblem{
32
+ <: Union{Number, <:AbstractArray} , iip,
33
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
34
+ } where {iip, T, V, P}
35
+
36
+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
37
+ <: Union{Number, <:AbstractArray} , iip,
38
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
39
+ } where {iip, T, V, P}
29
40
30
- const safe_similar = NonlinearSolveBase. Utils. safe_similar
41
+ abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end
42
+
43
+ const NLBUtils = NonlinearSolveBase. Utils
31
44
32
45
is_extension_loaded (:: Val ) = false
33
46
@@ -42,61 +55,66 @@ include("raphson.jl")
42
55
include (" trust_region.jl" )
43
56
44
57
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
45
- function CommonSolve. solve (prob:: NonlinearProblem ,
46
- alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... )
58
+ function CommonSolve. solve (
59
+ prob:: NonlinearProblem , alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
60
+ kwargs...
61
+ )
47
62
prob = convert (ImmutableNonlinearProblem, prob)
48
63
return solve (prob, alg, args... ; kwargs... )
49
64
end
50
65
51
66
function CommonSolve. solve (
52
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
53
- <: Union {
54
- <: ForwardDiff.Dual{T, V, P} , <: AbstractArray{<:ForwardDiff.Dual{T, V, P}} }},
55
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
56
- args... ;
57
- kwargs... ) where {T, V, P, iip}
67
+ prob:: DualNonlinearProblem , alg:: AbstractSimpleNonlinearSolveAlgorithm ,
68
+ args... ; kwargs...
69
+ )
58
70
if hasfield (typeof (alg), :autodiff ) && alg. autodiff === nothing
59
- @reset alg. autodiff = AutoForwardDiff ()
71
+ @set! alg. autodiff = AutoForwardDiff ()
60
72
end
61
73
prob = convert (ImmutableNonlinearProblem, prob)
62
74
sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
63
75
dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
64
76
return SciMLBase. build_solution (
65
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
77
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
78
+ )
66
79
end
67
80
68
81
function CommonSolve. solve (
69
- prob:: NonlinearLeastSquaresProblem {<: Union{Number, <:AbstractArray} , iip,
70
- <: Union {
71
- <: ForwardDiff.Dual{T, V, P} , <: AbstractArray{<:ForwardDiff.Dual{T, V, P}} }},
72
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
73
- args... ;
74
- kwargs... ) where {T, V, P, iip}
82
+ prob:: DualNonlinearLeastSquaresProblem , alg:: AbstractSimpleNonlinearSolveAlgorithm ,
83
+ args... ; kwargs...
84
+ )
75
85
if hasfield (typeof (alg), :autodiff ) && alg. autodiff === nothing
76
- @reset alg. autodiff = AutoForwardDiff ()
86
+ @set! alg. autodiff = AutoForwardDiff ()
77
87
end
78
88
sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
79
89
dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
80
90
return SciMLBase. build_solution (
81
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
91
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
92
+ )
82
93
end
83
94
84
95
function CommonSolve. solve (
85
96
prob:: Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem} ,
86
97
alg:: AbstractSimpleNonlinearSolveAlgorithm ,
87
- args... ; sensealg = nothing , u0 = nothing , p = nothing , kwargs... )
98
+ args... ; sensealg = nothing , u0 = nothing , p = nothing , kwargs...
99
+ )
88
100
if sensealg === nothing && haskey (prob. kwargs, :sensealg )
89
101
sensealg = prob. kwargs[:sensealg ]
90
102
end
91
103
new_u0 = u0 != = nothing ? u0 : prob. u0
92
104
new_p = p != = nothing ? p : prob. p
93
- return simplenonlinearsolve_solve_up (prob, sensealg, new_u0, u0 === nothing , new_p,
94
- p === nothing , alg, args... ; prob. kwargs... , kwargs... )
105
+ return simplenonlinearsolve_solve_up (
106
+ prob, sensealg,
107
+ new_u0, u0 === nothing ,
108
+ new_p, p === nothing ,
109
+ alg, args... ;
110
+ prob. kwargs... , kwargs...
111
+ )
95
112
end
96
113
97
114
function simplenonlinearsolve_solve_up (
98
115
prob:: Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem} , sensealg, u0,
99
- u0_changed, p, p_changed, alg, args... ; kwargs... )
116
+ u0_changed, p, p_changed, alg, args... ; kwargs...
117
+ )
100
118
(u0_changed || p_changed) && (prob = remake (prob; u0, p))
101
119
return SciMLBase. __solve (prob, alg, args... ; kwargs... )
102
120
end
@@ -131,7 +149,7 @@ function solve_adjoint_internal end
131
149
132
150
@compile_workload begin
133
151
for prob in (prob_scalar, prob_iip, prob_oop), alg in algs
134
- CommonSolve. solve (prob, alg; abstol = 1e-2 )
152
+ CommonSolve. solve (prob, alg; abstol = 1e-2 , verbose = false )
135
153
end
136
154
end
137
155
end
0 commit comments