Skip to content

Commit 254c2fb

Browse files
committed
refactor(SimpleNonlinearSolve): reuse more code from NLB
1 parent 99d3216 commit 254c2fb

21 files changed

+284
-251
lines changed

lib/BracketingNonlinearSolve/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ConcreteStructs = "0.2.3"
2424
ExplicitImports = "1.10.1"
2525
ForwardDiff = "0.10.36"
2626
InteractiveUtils = "<0.0.1, 1"
27-
NonlinearSolveBase = "1"
27+
NonlinearSolveBase = "1.1"
2828
PrecompileTools = "1.2"
2929
Reexport = "1.2"
3030
SciMLBase = "2.50"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module NonlinearSolveBaseBandedMatricesExt
22

33
using BandedMatrices: BandedMatrix
44
using LinearAlgebra: Diagonal
5+
56
using NonlinearSolveBase: NonlinearSolveBase, Utils
67

78
# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

+12-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
2525

2626
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
2727
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
28-
alg, args...; kwargs...)
28+
alg, args...; kwargs...
29+
)
2930
p = Utils.value(prob.p)
3031
if prob isa IntervalNonlinearProblem
3132
tspan = Utils.value.(prob.tspan)
@@ -55,7 +56,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
5556
end
5657

5758
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
58-
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
59+
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
60+
)
5961
p = Utils.value(prob.p)
6062
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
6163
sol = solve(newprob, alg, args...; kwargs...)
@@ -168,13 +170,17 @@ function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
168170
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
169171
end
170172

171-
function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials,
172-
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
173+
function NonlinearSolveBase.nonlinearsolve_dual_solution(
174+
u::Number, partials,
175+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}
176+
) where {T, V, P}
173177
return Dual{T, V, P}(u, partials)
174178
end
175179

176-
function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials,
177-
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
180+
function NonlinearSolveBase.nonlinearsolve_dual_solution(
181+
u::AbstractArray, partials,
182+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}
183+
) where {T, V, P}
178184
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
179185
end
180186

lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module NonlinearSolveBaseLineSearchExt
22

33
using LineSearch: LineSearch, AbstractLineSearchCache
4-
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI
54
using SciMLBase: SciMLBase
65

6+
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI
7+
78
function NonlinearSolveBase.callback_into_cache!(
89
topcache, cache::AbstractLineSearchCache, args...
910
)

lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
module NonlinearSolveBaseLinearSolveExt
22

33
using ArrayInterface: ArrayInterface
4+
45
using CommonSolve: CommonSolve, init, solve!
5-
using LinearAlgebra: ColumnNorm
66
using LinearSolve: LinearSolve, QRFactorization, SciMLLinearSolveAlgorithm
7-
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
87
using SciMLBase: ReturnCode, LinearProblem
98

9+
using LinearAlgebra: ColumnNorm
10+
11+
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
12+
1013
function (cache::LinearSolveJLCache)(;
1114
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
12-
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...)
15+
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...
16+
)
1317
cache.stats.nsolve += 1
1418

1519
update_A!(cache, A, reuse_A_if_factorization)

lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module NonlinearSolveBaseSparseArraysExt
22

3-
using NonlinearSolveBase: NonlinearSolveBase, Utils
43
using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros, sparse
54

5+
using NonlinearSolveBase: NonlinearSolveBase, Utils
6+
67
function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC)
78
return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x))
89
end

lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseMatrixColoringsExt.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
module NonlinearSolveBaseSparseMatrixColoringsExt
22

33
using ADTypes: ADTypes, AbstractADType
4-
using NonlinearSolveBase: NonlinearSolveBase, Utils
54
using SciMLBase: SciMLBase, NonlinearFunction
5+
66
using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
77
LargestFirst
88

9+
using NonlinearSolveBase: NonlinearSolveBase, Utils
10+
911
Utils.is_extension_loaded(::Val{:SparseMatrixColorings}) = true
1012

1113
function NonlinearSolveBase.select_fastest_coloring_algorithm(

lib/NonlinearSolveBase/src/utils.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ maybe_unaliased(x::AbstractSciMLOperator, ::Bool) = x
138138
can_setindex(x) = ArrayInterface.can_setindex(x)
139139
can_setindex(::Number) = false
140140

141-
evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p) = evaluate_f!!(prob.f, fu, u, p)
141+
function evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p = prob.p)
142+
return evaluate_f!!(prob.f, fu, u, p)
143+
end
142144
function evaluate_f!!(f::NonlinearFunction, fu, u, p)
143145
if SciMLBase.isinplace(f)
144146
f(fu, u, p)

lib/SimpleNonlinearSolve/Project.toml

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "2.0.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8-
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
98
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
109
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
1110
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
@@ -21,6 +20,7 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
2120
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2221
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2322
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
23+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2424
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525

2626
[weakdeps]
@@ -37,10 +37,9 @@ SimpleNonlinearSolveTrackerExt = "Tracker"
3737

3838
[compat]
3939
ADTypes = "1.2"
40-
Accessors = "0.1"
4140
Aqua = "0.8.7"
4241
ArrayInterface = "7.16"
43-
BracketingNonlinearSolve = "1"
42+
BracketingNonlinearSolve = "1.1"
4443
ChainRulesCore = "1.24"
4544
CommonSolve = "0.2.4"
4645
ConcreteStructs = "0.2.3"
@@ -56,14 +55,15 @@ LineSearch = "0.1.3"
5655
LinearAlgebra = "1.10"
5756
MaybeInplace = "0.1.4"
5857
NonlinearProblemLibrary = "0.1.2"
59-
NonlinearSolveBase = "1"
58+
NonlinearSolveBase = "1.1"
6059
Pkg = "1.10"
6160
PolyesterForwardDiff = "0.1"
6261
PrecompileTools = "1.2"
6362
Random = "1.10"
6463
Reexport = "1.2"
6564
ReverseDiff = "1.15"
6665
SciMLBase = "2.50"
66+
Setfield = "1.1.1"
6767
StaticArrays = "1.9"
6868
StaticArraysCore = "1.4.3"
6969
Test = "1.10"

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
module SimpleNonlinearSolveChainRulesCoreExt
22

33
using ChainRulesCore: ChainRulesCore, NoTangent
4+
45
using NonlinearSolveBase: ImmutableNonlinearProblem
56
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
67

78
using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
89
solve_adjoint
910

10-
function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
11+
function ChainRulesCore.rrule(
12+
::typeof(simplenonlinearsolve_solve_up),
1113
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
12-
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
14+
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...
15+
)
1316
out, ∇internal = solve_adjoint(
14-
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
17+
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...
18+
)
1519
function ∇simplenonlinearsolve_solve_up(Δ)
1620
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ)
1721
return (
18-
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
22+
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...
23+
)
1924
end
2025
return out, ∇simplenonlinearsolve_solve_up
2126
end

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module SimpleNonlinearSolveReverseDiffExt
22

3-
using ArrayInterface: ArrayInterface
43
using NonlinearSolveBase: ImmutableNonlinearProblem
5-
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
64
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake
75

6+
using ArrayInterface: ArrayInterface
7+
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
8+
89
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
910
import SimpleNonlinearSolve: simplenonlinearsolve_solve_up
1011

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module SimpleNonlinearSolveTrackerExt
22

3-
using ArrayInterface: ArrayInterface
43
using NonlinearSolveBase: ImmutableNonlinearProblem
54
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
5+
6+
using ArrayInterface: ArrayInterface
67
using Tracker: Tracker, TrackedArray, TrackedReal
78

89
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

+55-37
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,46 @@
11
module SimpleNonlinearSolve
22

3-
using Accessors: @reset
4-
using BracketingNonlinearSolve: BracketingNonlinearSolve
5-
using CommonSolve: CommonSolve, solve, init, solve!
63
using ConcreteStructs: @concrete
74
using FastClosures: @closure
8-
using LineSearch: LiFukushimaLineSearch
9-
using LinearAlgebra: LinearAlgebra, dot
10-
using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex
115
using PrecompileTools: @compile_workload, @setup_workload
126
using Reexport: @reexport
13-
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, NonlinearFunction, NonlinearProblem,
14-
NonlinearLeastSquaresProblem, IntervalNonlinearProblem, ReturnCode, remake
7+
using Setfield: @set!
8+
9+
using BracketingNonlinearSolve: BracketingNonlinearSolve
10+
using CommonSolve: CommonSolve, solve, init, solve!
11+
using LineSearch: LiFukushimaLineSearch
12+
using MaybeInplace: @bb
13+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
14+
nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution,
15+
AbstractNonlinearSolveAlgorithm
16+
using SciMLBase: SciMLBase, NonlinearFunction, NonlinearProblem,
17+
NonlinearLeastSquaresProblem, ReturnCode, remake
18+
19+
using LinearAlgebra: LinearAlgebra, dot
20+
1521
using StaticArraysCore: StaticArray, SArray, SVector, MArray
1622

1723
# AD Dependencies
1824
using ADTypes: ADTypes, AutoForwardDiff
1925
using DifferentiationInterface: DifferentiationInterface
2026
using FiniteDiff: FiniteDiff
21-
using ForwardDiff: ForwardDiff
22-
23-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
24-
nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
27+
using ForwardDiff: ForwardDiff, Dual
2528

2629
const DI = DifferentiationInterface
2730

28-
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
31+
const DualNonlinearProblem = NonlinearProblem{
32+
<:Union{Number, <:AbstractArray}, iip,
33+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
34+
} where {iip, T, V, P}
35+
36+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
37+
<:Union{Number, <:AbstractArray}, iip,
38+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
39+
} where {iip, T, V, P}
2940

30-
const safe_similar = NonlinearSolveBase.Utils.safe_similar
41+
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end
42+
43+
const NLBUtils = NonlinearSolveBase.Utils
3144

3245
is_extension_loaded(::Val) = false
3346

@@ -42,61 +55,66 @@ include("raphson.jl")
4255
include("trust_region.jl")
4356

4457
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
45-
function CommonSolve.solve(prob::NonlinearProblem,
46-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
58+
function CommonSolve.solve(
59+
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
60+
kwargs...
61+
)
4762
prob = convert(ImmutableNonlinearProblem, prob)
4863
return solve(prob, alg, args...; kwargs...)
4964
end
5065

5166
function CommonSolve.solve(
52-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
53-
<:Union{
54-
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
55-
alg::AbstractSimpleNonlinearSolveAlgorithm,
56-
args...;
57-
kwargs...) where {T, V, P, iip}
67+
prob::DualNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
68+
args...; kwargs...
69+
)
5870
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
59-
@reset alg.autodiff = AutoForwardDiff()
71+
@set! alg.autodiff = AutoForwardDiff()
6072
end
6173
prob = convert(ImmutableNonlinearProblem, prob)
6274
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
6375
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
6476
return SciMLBase.build_solution(
65-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
77+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
78+
)
6679
end
6780

6881
function CommonSolve.solve(
69-
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
70-
<:Union{
71-
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
72-
alg::AbstractSimpleNonlinearSolveAlgorithm,
73-
args...;
74-
kwargs...) where {T, V, P, iip}
82+
prob::DualNonlinearLeastSquaresProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
83+
args...; kwargs...
84+
)
7585
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
76-
@reset alg.autodiff = AutoForwardDiff()
86+
@set! alg.autodiff = AutoForwardDiff()
7787
end
7888
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
7989
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
8090
return SciMLBase.build_solution(
81-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
91+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
92+
)
8293
end
8394

8495
function CommonSolve.solve(
8596
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
8697
alg::AbstractSimpleNonlinearSolveAlgorithm,
87-
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
98+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...
99+
)
88100
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
89101
sensealg = prob.kwargs[:sensealg]
90102
end
91103
new_u0 = u0 !== nothing ? u0 : prob.u0
92104
new_p = p !== nothing ? p : prob.p
93-
return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
94-
p === nothing, alg, args...; prob.kwargs..., kwargs...)
105+
return simplenonlinearsolve_solve_up(
106+
prob, sensealg,
107+
new_u0, u0 === nothing,
108+
new_p, p === nothing,
109+
alg, args...;
110+
prob.kwargs..., kwargs...
111+
)
95112
end
96113

97114
function simplenonlinearsolve_solve_up(
98115
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
99-
u0_changed, p, p_changed, alg, args...; kwargs...)
116+
u0_changed, p, p_changed, alg, args...; kwargs...
117+
)
100118
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
101119
return SciMLBase.__solve(prob, alg, args...; kwargs...)
102120
end
@@ -131,7 +149,7 @@ function solve_adjoint_internal end
131149

132150
@compile_workload begin
133151
for prob in (prob_scalar, prob_iip, prob_oop), alg in algs
134-
CommonSolve.solve(prob, alg; abstol = 1e-2)
152+
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
135153
end
136154
end
137155
end

0 commit comments

Comments
 (0)