|
"""
    GeneralizedFirstOrderAlgorithm(;
        descent, linesearch = missing,
        trustregion = missing, autodiff = nothing, vjp_autodiff = nothing,
        jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
        concrete_jac = Val(false), name::Symbol = :unknown
    )

A generalization of first-order nonlinear solve algorithms, i.e. algorithms that make use of
the Jacobian. The most common example is the Newton-Raphson method.

"First order" refers to the order of differentiation and should not be confused with the
order of convergence.

### Keyword Arguments

  - `descent`: The descent method to use to compute the step. This needs to follow the
    [`NonlinearSolve.AbstractDescentAlgorithm`](@ref) interface. This keyword is required.
  - `linesearch`: Globalization using a line search method. Defaults to `missing`.
  - `trustregion`: Globalization using a Trust Region Method. This needs to follow the
    [`NonlinearSolve.AbstractTrustRegionMethod`](@ref) interface. Defaults to `missing`.
  - `max_shrink_times`: The maximum number of times the trust region radius can be shrunk
    before the algorithm terminates. Defaults to `typemax(Int)`.
  - `autodiff`, `vjp_autodiff`, `jvp_autodiff`: Automatic differentiation choices for the
    Jacobian, vector-Jacobian products and Jacobian-vector products respectively. Each
    defaults to `nothing`.
  - `concrete_jac`: Type-level flag, either `Val(false)` (the default) or `Val(true)`.
  - `name`: A `Symbol` stored on the algorithm. Defaults to `:unknown`.
"""
@concrete struct GeneralizedFirstOrderAlgorithm <: AbstractNonlinearSolveAlgorithm
    # NOTE: field order is part of the positional constructor's contract; do not reorder.

    # Globalization strategy and step computation
    linesearch
    trustregion
    descent
    max_shrink_times::Int

    # Automatic differentiation choices
    autodiff
    vjp_autodiff
    jvp_autodiff

    # Restricted to exactly `Val{false}` or `Val{true}`
    # NOTE(review): presumably controls whether a concrete Jacobian is materialized; the
    # consumer of this flag is not visible in this file, so confirm before relying on it.
    concrete_jac <: Union{Val{false}, Val{true}}
    name::Symbol
end
| 37 | + |
# Lift the `concrete_jac` keyword into the `Val` form required by the struct field
# (`concrete_jac <: Union{Val{false}, Val{true}}`):
#   - a `Bool` becomes `Val(true)` / `Val(false)`,
#   - `nothing` is treated as "not requested" and becomes `Val(false)`,
#   - anything else is forwarded untouched, so the struct's own type constraint remains the
#     single place where invalid values are rejected.
_concrete_jac_as_val(flag::Bool) = Val(flag)
_concrete_jac_as_val(::Nothing) = Val(false)
_concrete_jac_as_val(flag) = flag

# Keyword constructor for `GeneralizedFirstOrderAlgorithm`.
#
# `descent` is the only required keyword. All keywords are forwarded to the positional
# constructor in the struct's field order. `concrete_jac` may be given as `Val(true)`,
# `Val(false)`, a plain `Bool`, or `nothing`; the latter two are normalized to a `Val` first
# because the struct field only admits `Val{true}` / `Val{false}`.
function GeneralizedFirstOrderAlgorithm(;
        descent, linesearch = missing, trustregion = missing, autodiff = nothing,
        vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
        concrete_jac = Val(false), name::Symbol = :unknown)
    return GeneralizedFirstOrderAlgorithm(
        linesearch, trustregion, descent, max_shrink_times,
        autodiff, vjp_autodiff, jvp_autodiff,
        _concrete_jac_as_val(concrete_jac), name
    )
end
| 48 | + |
# Mutable solver state associated with a `GeneralizedFirstOrderAlgorithm` (see the `alg`
# field). Fields without a type annotation are left to `@concrete`, which is expected to
# turn them into type parameters so the struct stays concretely typed.
#
# NOTE(review): the commented-out `SciMLBase.__init` further down in this file builds this
# cache positionally without a `globalization` argument and with explicit type parameters;
# it predates this layout and must be adapted when it is re-enabled.
@concrete mutable struct GeneralizedFirstOrderAlgorithmCache <: AbstractNonlinearSolveCache
    # --- Basic requirements ---
    fu
    u
    u_cache
    p
    du  # aliased to `get_du(descent_cache)`
    J   # aliased to `jac_cache.J`
    alg <: GeneralizedFirstOrderAlgorithm
    prob <: AbstractNonlinearProblem
    # Exactly one of the three `Val` tags; selects the globalization strategy by type
    globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}}

    # --- Internal caches ---
    jac_cache
    descent_cache
    linesearch_cache
    trustregion_cache

    # --- Counters and limits ---
    stats::NLStats
    nsteps::Int
    maxiters::Int
    maxtime
    max_shrink_times::Int

    # --- Timing ---
    timer
    total_time::Float64

    # --- State that affects the next step ---
    # NOTE(review): presumably decides whether the next step recomputes the Jacobian, as in
    # the commented-out `__step!` in this file; confirm against the ported implementation.
    make_new_jacobian::Bool

    # --- Termination and tracking ---
    termination_cache
    trace
    retcode::ReturnCode.T
    force_stop::Bool
    kwargs
end
| 88 | + |
| 89 | +# XXX: Implement |
| 90 | +# function __reinit_internal!( |
| 91 | +# cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u, |
| 92 | +# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip} |
| 93 | +# if iip |
| 94 | +# recursivecopy!(cache.u, u0) |
| 95 | +# cache.prob.f(cache.fu, cache.u, p) |
| 96 | +# else |
| 97 | +# cache.u = __maybe_unaliased(u0, alias_u0) |
| 98 | +# set_fu!(cache, cache.prob.f(cache.u, p)) |
| 99 | +# end |
| 100 | +# cache.p = p |
| 101 | + |
| 102 | +# __reinit_internal!(cache.stats) |
| 103 | +# cache.nsteps = 0 |
| 104 | +# cache.maxiters = maxiters |
| 105 | +# cache.maxtime = maxtime |
| 106 | +# cache.total_time = 0.0 |
| 107 | +# cache.force_stop = false |
| 108 | +# cache.retcode = ReturnCode.Default |
| 109 | +# cache.make_new_jacobian = true |
| 110 | + |
| 111 | +# reset!(cache.trace) |
| 112 | +# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...) |
| 113 | +# reset_timer!(cache.timer) |
| 114 | +# end |
| 115 | + |
# Declare which fields of `GeneralizedFirstOrderAlgorithmCache` hold nested caches (the
# fields grouped under "Internal Caches" in the struct definition).
# NOTE(review): what `@internal_caches` generates is defined in NonlinearSolveBase and is not
# visible from this file; confirm there before relying on specific generated methods.
NonlinearSolveBase.@internal_caches(
    GeneralizedFirstOrderAlgorithmCache,
    :jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache
)
| 118 | + |
| 119 | +# function SciMLBase.__init( |
| 120 | +# prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm, |
| 121 | +# args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000, |
| 122 | +# abstol = nothing, reltol = nothing, maxtime = nothing, |
| 123 | +# termination_condition = nothing, internalnorm = L2_NORM, |
| 124 | +# linsolve_kwargs = (;), kwargs...) where {uType, iip} |
| 125 | +# autodiff = select_jacobian_autodiff(prob, alg.autodiff) |
| 126 | +# jvp_autodiff = if alg.jvp_autodiff === nothing && alg.autodiff !== nothing && |
| 127 | +# (ADTypes.mode(alg.autodiff) isa ADTypes.ForwardMode || |
| 128 | +# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode) |
| 129 | +# select_forward_mode_autodiff(prob, alg.autodiff) |
| 130 | +# else |
| 131 | +# select_forward_mode_autodiff(prob, alg.jvp_autodiff) |
| 132 | +# end |
| 133 | +# vjp_autodiff = if alg.vjp_autodiff === nothing && alg.autodiff !== nothing && |
| 134 | +# (ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode || |
| 135 | +# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode) |
| 136 | +# select_reverse_mode_autodiff(prob, alg.autodiff) |
| 137 | +# else |
| 138 | +# select_reverse_mode_autodiff(prob, alg.vjp_autodiff) |
| 139 | +# end |
| 140 | + |
| 141 | +# timer = get_timer_output() |
| 142 | +# @static_timeit timer "cache construction" begin |
| 143 | +# (; f, u0, p) = prob |
| 144 | +# u = __maybe_unaliased(u0, alias_u0) |
| 145 | +# fu = evaluate_f(prob, u) |
| 146 | +# @bb u_cache = copy(u) |
| 147 | + |
| 148 | +# linsolve = get_linear_solver(alg.descent) |
| 149 | + |
| 150 | +# abstol, reltol, termination_cache = NonlinearSolveBase.init_termination_cache( |
| 151 | +# prob, abstol, reltol, fu, u, termination_condition, Val(:regular)) |
| 152 | +# linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs) |
| 153 | + |
| 154 | +# jac_cache = construct_jacobian_cache( |
| 155 | +# prob, alg, f, fu, u, p; stats, autodiff, linsolve, jvp_autodiff, vjp_autodiff) |
| 156 | +# J = jac_cache(nothing) |
| 157 | + |
| 158 | +# descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, |
| 159 | +# reltol, internalnorm, linsolve_kwargs, timer) |
| 160 | +# du = get_du(descent_cache) |
| 161 | + |
| 162 | +# has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing |
| 163 | +# has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing |
| 164 | + |
| 165 | +# if has_trustregion && has_linesearch |
| 166 | +# error("TrustRegion and LineSearch methods are algorithmically incompatible.") |
| 167 | +# end |
| 168 | + |
| 169 | +# GB = :None |
| 170 | +# linesearch_cache = nothing |
| 171 | +# trustregion_cache = nothing |
| 172 | + |
| 173 | +# if has_trustregion |
| 174 | +# supports_trust_region(alg.descent) || error("Trust Region not supported by \ |
| 175 | +# $(alg.descent).") |
| 176 | +# trustregion_cache = __internal_init( |
| 177 | +# prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs..., |
| 178 | +# autodiff, jvp_autodiff, vjp_autodiff) |
| 179 | +# GB = :TrustRegion |
| 180 | +# end |
| 181 | + |
| 182 | +# if has_linesearch |
| 183 | +# supports_line_search(alg.descent) || error("Line Search not supported by \ |
| 184 | +# $(alg.descent).") |
| 185 | +# linesearch_cache = init( |
| 186 | +# prob, alg.linesearch, fu, u; stats, autodiff = jvp_autodiff, kwargs...) |
| 187 | +# GB = :LineSearch |
| 188 | +# end |
| 189 | + |
| 190 | +# trace = init_nonlinearsolve_trace( |
| 191 | +# prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...) |
| 192 | + |
| 193 | +# return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}( |
| 194 | +# fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache, |
| 195 | +# trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, |
| 196 | +# 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs) |
| 197 | +# end |
| 198 | +# end |
| 199 | + |
| 200 | +# function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB}; |
| 201 | +# recompute_jacobian::Union{Nothing, Bool} = nothing, kwargs...) where {iip, GB} |
| 202 | +# @static_timeit cache.timer "jacobian" begin |
| 203 | +# if (recompute_jacobian === nothing || recompute_jacobian) && cache.make_new_jacobian |
| 204 | +# J = cache.jac_cache(cache.u) |
| 205 | +# new_jacobian = true |
| 206 | +# else |
| 207 | +# J = cache.jac_cache(nothing) |
| 208 | +# new_jacobian = false |
| 209 | +# end |
| 210 | +# end |
| 211 | + |
| 212 | +# @static_timeit cache.timer "descent" begin |
| 213 | +# if cache.trustregion_cache !== nothing && |
| 214 | +# hasfield(typeof(cache.trustregion_cache), :trust_region) |
| 215 | +# descent_result = __internal_solve!( |
| 216 | +# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, |
| 217 | +# trust_region = cache.trustregion_cache.trust_region, cache.kwargs...) |
| 218 | +# else |
| 219 | +# descent_result = __internal_solve!( |
| 220 | +# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...) |
| 221 | +# end |
| 222 | +# end |
| 223 | + |
| 224 | +# if !descent_result.linsolve_success |
| 225 | +# if new_jacobian |
| 226 | +# # Jacobian Information is current and linear solve failed terminate the solve |
| 227 | +# cache.retcode = ReturnCode.InternalLinearSolveFailed |
| 228 | +# cache.force_stop = true |
| 229 | +# return |
| 230 | +# else |
| 231 | +# # Jacobian Information is not current and linear solve failed, recompute |
| 232 | +# # Jacobian |
| 233 | +# if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose] |
| 234 | +# @warn "Linear Solve Failed but Jacobian Information is not current. \ |
| 235 | +# Retrying with updated Jacobian." |
| 236 | +# end |
| 237 | +# # In the 2nd call the `new_jacobian` is guaranteed to be `true`. |
| 238 | +# cache.make_new_jacobian = true |
| 239 | +# __step!(cache; recompute_jacobian = true, kwargs...) |
| 240 | +# return |
| 241 | +# end |
| 242 | +# end |
| 243 | + |
| 244 | +# δu, descent_intermediates = descent_result.δu, descent_result.extras |
| 245 | + |
| 246 | +# if descent_result.success |
| 247 | +# cache.make_new_jacobian = true |
| 248 | +# if GB === :LineSearch |
| 249 | +# @static_timeit cache.timer "linesearch" begin |
| 250 | +# linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu) |
| 251 | +# linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode) |
| 252 | +# α = linesearch_sol.step_size |
| 253 | +# end |
| 254 | +# if linesearch_failed |
| 255 | +# cache.retcode = ReturnCode.InternalLineSearchFailed |
| 256 | +# cache.force_stop = true |
| 257 | +# end |
| 258 | +# @static_timeit cache.timer "step" begin |
| 259 | +# @bb axpy!(α, δu, cache.u) |
| 260 | +# evaluate_f!(cache, cache.u, cache.p) |
| 261 | +# end |
| 262 | +# elseif GB === :TrustRegion |
| 263 | +# @static_timeit cache.timer "trustregion" begin |
| 264 | +# tr_accepted, u_new, fu_new = __internal_solve!( |
| 265 | +# cache.trustregion_cache, J, cache.fu, |
| 266 | +# cache.u, δu, descent_intermediates) |
| 267 | +# if tr_accepted |
| 268 | +# @bb copyto!(cache.u, u_new) |
| 269 | +# @bb copyto!(cache.fu, fu_new) |
| 270 | +# α = true |
| 271 | +# else |
| 272 | +# α = false |
| 273 | +# cache.make_new_jacobian = false |
| 274 | +# end |
| 275 | +# if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && |
| 276 | +# cache.trustregion_cache.shrink_counter > cache.max_shrink_times |
| 277 | +# cache.retcode = ReturnCode.ShrinkThresholdExceeded |
| 278 | +# cache.force_stop = true |
| 279 | +# end |
| 280 | +# end |
| 281 | +# elseif GB === :None |
| 282 | +# @static_timeit cache.timer "step" begin |
| 283 | +# @bb axpy!(1, δu, cache.u) |
| 284 | +# evaluate_f!(cache, cache.u, cache.p) |
| 285 | +# end |
| 286 | +# α = true |
| 287 | +# else |
| 288 | +# error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \ |
| 289 | +# :TrustRegion, :None)") |
| 290 | +# end |
| 291 | +# check_and_update!(cache, cache.fu, cache.u, cache.u_cache) |
| 292 | +# else |
| 293 | +# α = false |
| 294 | +# cache.make_new_jacobian = false |
| 295 | +# end |
| 296 | + |
| 297 | +# update_trace!(cache, α) |
| 298 | +# @bb copyto!(cache.u_cache, cache.u) |
| 299 | + |
| 300 | +# callback_into_cache!(cache) |
| 301 | + |
| 302 | +# return nothing |
| 303 | +# end |
0 commit comments