Skip to content

Commit 3075a08

Browse files
committed
refactor: delete more code
1 parent 8a22202 commit 3075a08

File tree

18 files changed

+359
-512
lines changed

18 files changed

+359
-512
lines changed

Project.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1313
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1414
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1515
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
16-
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1716
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
1817
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1918
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2019
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
2120
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
21+
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
2222
NonlinearSolveQuasiNewton = "9a2c21bd-3a47-402d-9113-8faf9a0ee114"
2323
NonlinearSolveSpectralMethods = "26075421-4e9a-44e1-8bd1-420ed7ad02b2"
2424
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -81,7 +81,6 @@ FixedPointAcceleration = "0.3"
8181
ForwardDiff = "0.10.36"
8282
Hwloc = "3"
8383
InteractiveUtils = "<0.0.1, 1"
84-
LazyArrays = "1.8.2, 2"
8584
LeastSquaresOptim = "0.8.5"
8685
LineSearch = "0.1.4"
8786
LineSearches = "7.3"

lib/NonlinearSolveBase/src/solve.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
3232
end
3333

3434
"""
35-
step!(cache::AbstractNonlinearSolveCache;
36-
recompute_jacobian::Union{Nothing, Bool} = nothing)
35+
step!(
36+
cache::AbstractNonlinearSolveCache;
37+
recompute_jacobian::Union{Nothing, Bool} = nothing
38+
)
3739
3840
Performs one step of the nonlinear solver.
3941

lib/NonlinearSolveBase/src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ function clean_sprint_struct(x)
248248
name = nameof(typeof(x))
249249
for field in fieldnames(typeof(x))
250250
val = getfield(x, field)
251-
if field === :name
251+
if field === :name && val isa Symbol && val !== :unknown
252252
name = val
253253
continue
254254
end
@@ -268,7 +268,7 @@ function clean_sprint_struct(x, indent::Int)
268268
name = nameof(typeof(x))
269269
for field in fieldnames(typeof(x))
270270
val = getfield(x, field)
271-
if field === :name
271+
if field === :name && val isa Symbol && val !== :unknown
272272
name = val
273273
continue
274274
end
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,40 @@
11
module NonlinearSolveFirstOrder
22

3+
using Reexport: @reexport
4+
using PrecompileTools: @compile_workload, @setup_workload
5+
6+
using ArrayInterface: ArrayInterface
7+
using CommonSolve: CommonSolve
8+
using ConcreteStructs: @concrete
9+
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
10+
using LinearAlgebra: LinearAlgebra, Diagonal, dot, inv, diag
11+
using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase
12+
using MaybeInplace: @bb
13+
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
14+
AbstractNonlinearSolveCache, AbstractResetCondition,
15+
AbstractResetConditionCache, AbstractApproximateJacobianStructure,
16+
AbstractJacobianCache, AbstractJacobianInitialization,
17+
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection,
18+
AbstractApproximateJacobianUpdateRuleCache,
19+
Utils, InternalAPI, get_timer_output, @static_timeit,
20+
update_trace!, L2_NORM, NewtonDescent
21+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
22+
using SciMLOperators: AbstractSciMLOperator
23+
using StaticArraysCore: StaticArray, Size, MArray
24+
25+
include("raphson.jl")
26+
include("gauss_newton.jl")
27+
include("levenberg_marquardt.jl")
28+
include("trust_region.jl")
29+
include("pseudo_transient.jl")
30+
31+
include("solve.jl")
32+
33+
@reexport using SciMLBase, NonlinearSolveBase
34+
35+
export NewtonRaphson, PseudoTransient
36+
export GaussNewton, LevenbergMarquardt, TrustRegion
37+
38+
export GeneralizedFirstOrderAlgorithm
39+
340
end

lib/NonlinearSolveFirstOrder/src/gauss_newton.jl

Whitespace-only changes.

lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl

Whitespace-only changes.

lib/NonlinearSolveFirstOrder/src/pseudo_transient.jl

Whitespace-only changes.

lib/NonlinearSolveFirstOrder/src/raphson.jl

Whitespace-only changes.
+303
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""
2+
GeneralizedFirstOrderAlgorithm(;
3+
descent, linesearch = missing,
4+
trustregion = missing, autodiff = nothing, vjp_autodiff = nothing,
5+
jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
6+
concrete_jac = Val(false), name::Symbol = :unknown
7+
)
8+
9+
This is a Generalization of First-Order (uses Jacobian) Nonlinear Solve Algorithms. The most
10+
common example of this is Newton-Raphson Method.
11+
12+
First Order here refers to the order of differentiation, and should not be confused with the
13+
order of convergence.
14+
15+
### Keyword Arguments
16+
17+
- `trustregion`: Globalization using a Trust Region Method. This needs to follow the
18+
[`NonlinearSolve.AbstractTrustRegionMethod`](@ref) interface.
19+
- `descent`: The descent method to use to compute the step. This needs to follow the
20+
[`NonlinearSolve.AbstractDescentAlgorithm`](@ref) interface.
21+
- `max_shrink_times`: The maximum number of times the trust region radius can be shrunk
22+
before the algorithm terminates.
23+
"""
24+
@concrete struct GeneralizedFirstOrderAlgorithm <: AbstractNonlinearSolveAlgorithm
25+
linesearch
26+
trustregion
27+
descent
28+
max_shrink_times::Int
29+
30+
autodiff
31+
vjp_autodiff
32+
jvp_autodiff
33+
34+
concrete_jac <: Union{Val{false}, Val{true}}
35+
name::Symbol
36+
end
37+
38+
function GeneralizedFirstOrderAlgorithm(;
39+
descent, linesearch = missing, trustregion = missing, autodiff = nothing,
40+
vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
41+
concrete_jac = Val(false), name::Symbol = :unknown)
42+
return GeneralizedFirstOrderAlgorithm(
43+
linesearch, trustregion, descent, max_shrink_times,
44+
autodiff, vjp_autodiff, jvp_autodiff,
45+
concrete_jac, name
46+
)
47+
end
48+
49+
@concrete mutable struct GeneralizedFirstOrderAlgorithmCache <: AbstractNonlinearSolveCache
50+
# Basic Requirements
51+
fu
52+
u
53+
u_cache
54+
p
55+
du # Aliased to `get_du(descent_cache)`
56+
J # Aliased to `jac_cache.J`
57+
alg <: GeneralizedFirstOrderAlgorithm
58+
prob <: AbstractNonlinearProblem
59+
globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}}
60+
61+
# Internal Caches
62+
jac_cache
63+
descent_cache
64+
linesearch_cache
65+
trustregion_cache
66+
67+
# Counters
68+
stats::NLStats
69+
nsteps::Int
70+
maxiters::Int
71+
maxtime
72+
max_shrink_times::Int
73+
74+
# Timer
75+
timer
76+
total_time::Float64
77+
78+
# State Affect
79+
make_new_jacobian::Bool
80+
81+
# Termination & Tracking
82+
termination_cache
83+
trace
84+
retcode::ReturnCode.T
85+
force_stop::Bool
86+
kwargs
87+
end
88+
89+
# XXX: Implement
90+
# function __reinit_internal!(
91+
# cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u,
92+
# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}
93+
# if iip
94+
# recursivecopy!(cache.u, u0)
95+
# cache.prob.f(cache.fu, cache.u, p)
96+
# else
97+
# cache.u = __maybe_unaliased(u0, alias_u0)
98+
# set_fu!(cache, cache.prob.f(cache.u, p))
99+
# end
100+
# cache.p = p
101+
102+
# __reinit_internal!(cache.stats)
103+
# cache.nsteps = 0
104+
# cache.maxiters = maxiters
105+
# cache.maxtime = maxtime
106+
# cache.total_time = 0.0
107+
# cache.force_stop = false
108+
# cache.retcode = ReturnCode.Default
109+
# cache.make_new_jacobian = true
110+
111+
# reset!(cache.trace)
112+
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
113+
# reset_timer!(cache.timer)
114+
# end
115+
116+
NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache,
117+
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache)
118+
119+
# function SciMLBase.__init(
120+
# prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
121+
# args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
122+
# abstol = nothing, reltol = nothing, maxtime = nothing,
123+
# termination_condition = nothing, internalnorm = L2_NORM,
124+
# linsolve_kwargs = (;), kwargs...) where {uType, iip}
125+
# autodiff = select_jacobian_autodiff(prob, alg.autodiff)
126+
# jvp_autodiff = if alg.jvp_autodiff === nothing && alg.autodiff !== nothing &&
127+
# (ADTypes.mode(alg.autodiff) isa ADTypes.ForwardMode ||
128+
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode)
129+
# select_forward_mode_autodiff(prob, alg.autodiff)
130+
# else
131+
# select_forward_mode_autodiff(prob, alg.jvp_autodiff)
132+
# end
133+
# vjp_autodiff = if alg.vjp_autodiff === nothing && alg.autodiff !== nothing &&
134+
# (ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode ||
135+
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode)
136+
# select_reverse_mode_autodiff(prob, alg.autodiff)
137+
# else
138+
# select_reverse_mode_autodiff(prob, alg.vjp_autodiff)
139+
# end
140+
141+
# timer = get_timer_output()
142+
# @static_timeit timer "cache construction" begin
143+
# (; f, u0, p) = prob
144+
# u = __maybe_unaliased(u0, alias_u0)
145+
# fu = evaluate_f(prob, u)
146+
# @bb u_cache = copy(u)
147+
148+
# linsolve = get_linear_solver(alg.descent)
149+
150+
# abstol, reltol, termination_cache = NonlinearSolveBase.init_termination_cache(
151+
# prob, abstol, reltol, fu, u, termination_condition, Val(:regular))
152+
# linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
153+
154+
# jac_cache = construct_jacobian_cache(
155+
# prob, alg, f, fu, u, p; stats, autodiff, linsolve, jvp_autodiff, vjp_autodiff)
156+
# J = jac_cache(nothing)
157+
158+
# descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol,
159+
# reltol, internalnorm, linsolve_kwargs, timer)
160+
# du = get_du(descent_cache)
161+
162+
# has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing
163+
# has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing
164+
165+
# if has_trustregion && has_linesearch
166+
# error("TrustRegion and LineSearch methods are algorithmically incompatible.")
167+
# end
168+
169+
# GB = :None
170+
# linesearch_cache = nothing
171+
# trustregion_cache = nothing
172+
173+
# if has_trustregion
174+
# supports_trust_region(alg.descent) || error("Trust Region not supported by \
175+
# $(alg.descent).")
176+
# trustregion_cache = __internal_init(
177+
# prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...,
178+
# autodiff, jvp_autodiff, vjp_autodiff)
179+
# GB = :TrustRegion
180+
# end
181+
182+
# if has_linesearch
183+
# supports_line_search(alg.descent) || error("Line Search not supported by \
184+
# $(alg.descent).")
185+
# linesearch_cache = init(
186+
# prob, alg.linesearch, fu, u; stats, autodiff = jvp_autodiff, kwargs...)
187+
# GB = :LineSearch
188+
# end
189+
190+
# trace = init_nonlinearsolve_trace(
191+
# prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
192+
193+
# return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
194+
# fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
195+
# trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
196+
# 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
197+
# end
198+
# end
199+
200+
# function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
201+
# recompute_jacobian::Union{Nothing, Bool} = nothing, kwargs...) where {iip, GB}
202+
# @static_timeit cache.timer "jacobian" begin
203+
# if (recompute_jacobian === nothing || recompute_jacobian) && cache.make_new_jacobian
204+
# J = cache.jac_cache(cache.u)
205+
# new_jacobian = true
206+
# else
207+
# J = cache.jac_cache(nothing)
208+
# new_jacobian = false
209+
# end
210+
# end
211+
212+
# @static_timeit cache.timer "descent" begin
213+
# if cache.trustregion_cache !== nothing &&
214+
# hasfield(typeof(cache.trustregion_cache), :trust_region)
215+
# descent_result = __internal_solve!(
216+
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian,
217+
# trust_region = cache.trustregion_cache.trust_region, cache.kwargs...)
218+
# else
219+
# descent_result = __internal_solve!(
220+
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
221+
# end
222+
# end
223+
224+
# if !descent_result.linsolve_success
225+
# if new_jacobian
226+
# # Jacobian Information is current and linear solve failed terminate the solve
227+
# cache.retcode = ReturnCode.InternalLinearSolveFailed
228+
# cache.force_stop = true
229+
# return
230+
# else
231+
# # Jacobian Information is not current and linear solve failed, recompute
232+
# # Jacobian
233+
# if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
234+
# @warn "Linear Solve Failed but Jacobian Information is not current. \
235+
# Retrying with updated Jacobian."
236+
# end
237+
# # In the 2nd call the `new_jacobian` is guaranteed to be `true`.
238+
# cache.make_new_jacobian = true
239+
# __step!(cache; recompute_jacobian = true, kwargs...)
240+
# return
241+
# end
242+
# end
243+
244+
# δu, descent_intermediates = descent_result.δu, descent_result.extras
245+
246+
# if descent_result.success
247+
# cache.make_new_jacobian = true
248+
# if GB === :LineSearch
249+
# @static_timeit cache.timer "linesearch" begin
250+
# linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
251+
# linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
252+
# α = linesearch_sol.step_size
253+
# end
254+
# if linesearch_failed
255+
# cache.retcode = ReturnCode.InternalLineSearchFailed
256+
# cache.force_stop = true
257+
# end
258+
# @static_timeit cache.timer "step" begin
259+
# @bb axpy!(α, δu, cache.u)
260+
# evaluate_f!(cache, cache.u, cache.p)
261+
# end
262+
# elseif GB === :TrustRegion
263+
# @static_timeit cache.timer "trustregion" begin
264+
# tr_accepted, u_new, fu_new = __internal_solve!(
265+
# cache.trustregion_cache, J, cache.fu,
266+
# cache.u, δu, descent_intermediates)
267+
# if tr_accepted
268+
# @bb copyto!(cache.u, u_new)
269+
# @bb copyto!(cache.fu, fu_new)
270+
# α = true
271+
# else
272+
# α = false
273+
# cache.make_new_jacobian = false
274+
# end
275+
# if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
276+
# cache.trustregion_cache.shrink_counter > cache.max_shrink_times
277+
# cache.retcode = ReturnCode.ShrinkThresholdExceeded
278+
# cache.force_stop = true
279+
# end
280+
# end
281+
# elseif GB === :None
282+
# @static_timeit cache.timer "step" begin
283+
# @bb axpy!(1, δu, cache.u)
284+
# evaluate_f!(cache, cache.u, cache.p)
285+
# end
286+
# α = true
287+
# else
288+
# error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
289+
# :TrustRegion, :None)")
290+
# end
291+
# check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
292+
# else
293+
# α = false
294+
# cache.make_new_jacobian = false
295+
# end
296+
297+
# update_trace!(cache, α)
298+
# @bb copyto!(cache.u_cache, cache.u)
299+
300+
# callback_into_cache!(cache)
301+
302+
# return nothing
303+
# end

lib/NonlinearSolveFirstOrder/src/trust_region.jl

Whitespace-only changes.

0 commit comments

Comments
 (0)