89
89
kwargs
90
90
end
91
91
92
- # XXX : Implement
93
- # function __reinit_internal!(
94
- # cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u,
95
- # alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}
96
- # if iip
97
- # recursivecopy!(cache.u, u0)
98
- # cache.prob.f(cache.fu, cache.u, p)
99
- # else
100
- # cache.u = __maybe_unaliased(u0, alias_u0)
101
- # set_fu!(cache, cache.prob.f(cache.u, p))
102
- # end
103
- # cache.p = p
104
-
105
- # __reinit_internal!(cache.stats)
106
- # cache.nsteps = 0
107
- # cache.maxiters = maxiters
108
- # cache.maxtime = maxtime
109
- # cache.total_time = 0.0
110
- # cache.force_stop = false
111
- # cache.retcode = ReturnCode.Default
112
- # cache.make_new_jacobian = true
113
-
114
- # reset!(cache.trace)
115
- # reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
116
- # reset_timer!(cache.timer)
117
- # end
92
+ function InternalAPI. reinit_self! (
93
+ cache:: GeneralizedFirstOrderAlgorithmCache , args... ; p = cache. p, u0 = cache. u,
94
+ alias_u0:: Bool = false , maxiters = 1000 , maxtime = nothing , kwargs...
95
+ )
96
+ Utils. reinit_common! (cache, u0, p, alias_u0)
97
+
98
+ InternalAPI. reinit! (cache. stats)
99
+ cache. nsteps = 0
100
+ cache. maxiters = maxiters
101
+ cache. maxtime = maxtime
102
+ cache. total_time = 0.0
103
+ cache. force_stop = false
104
+ cache. retcode = ReturnCode. Default
105
+ cache. make_new_jacobian = true
106
+
107
+ NonlinearSolveBase. reset! (cache. trace)
108
+ SciMLBase. reinit! (
109
+ cache. termination_cache, NonlinearSolveBase. get_fu (cache),
110
+ NonlinearSolveBase. get_u (cache); kwargs...
111
+ )
112
+ NonlinearSolveBase. reset_timer! (cache. timer)
113
+ return
114
+ end
118
115
119
116
NonlinearSolveBase. @internal_caches (GeneralizedFirstOrderAlgorithmCache,
120
117
:jac_cache , :descent_cache , :linesearch_cache , :trustregion_cache )
121
118
122
119
function SciMLBase. __init (
123
- prob:: AbstractNonlinearProblem , alg:: GeneralizedFirstOrderAlgorithm ,
124
- args ... ; stats = NLStats (0 , 0 , 0 , 0 , 0 ), alias_u0 = false , maxiters = 1000 ,
120
+ prob:: AbstractNonlinearProblem , alg:: GeneralizedFirstOrderAlgorithm , args ... ;
121
+ stats = NLStats (0 , 0 , 0 , 0 , 0 ), alias_u0 = false , maxiters = 1000 ,
125
122
abstol = nothing , reltol = nothing , maxtime = nothing ,
126
123
termination_condition = nothing , internalnorm = L2_NORM,
127
124
linsolve_kwargs = (;), kwargs...
128
125
)
129
126
@set! alg. autodiff = NonlinearSolveBase. select_jacobian_autodiff (prob, alg. autodiff)
130
- @set! alg. jvp_autodiff = if alg. jvp_autodiff === nothing && alg. autodiff != = nothing &&
127
+ provided_jvp_autodiff = alg. jvp_autodiff != = nothing
128
+ @set! alg. jvp_autodiff = if ! provided_jvp_autodiff && alg. autodiff != = nothing &&
131
129
(ADTypes. mode (alg. autodiff) isa ADTypes. ForwardMode ||
132
130
ADTypes. mode (alg. autodiff) isa
133
131
ADTypes. ForwardOrReverseMode)
134
132
NonlinearSolveBase. select_forward_mode_autodiff (prob, alg. autodiff)
135
133
else
136
134
NonlinearSolveBase. select_forward_mode_autodiff (prob, alg. jvp_autodiff)
137
135
end
138
- @set! alg. vjp_autodiff = if alg. vjp_autodiff === nothing && alg. autodiff != = nothing &&
136
+ provided_vjp_autodiff = alg. vjp_autodiff != = nothing
137
+ @set! alg. vjp_autodiff = if ! provided_vjp_autodiff && alg. autodiff != = nothing &&
139
138
(ADTypes. mode (alg. autodiff) isa ADTypes. ReverseMode ||
140
139
ADTypes. mode (alg. autodiff) isa
141
140
ADTypes. ForwardOrReverseMode)
@@ -185,7 +184,7 @@ function SciMLBase.__init(
185
184
error (" Trust Region not supported by $(alg. descent) ." )
186
185
trustregion_cache = InternalAPI. init (
187
186
prob, alg. trustregion, prob. f, fu, u, prob. p;
188
- stats, internalnorm, kwargs...
187
+ alg . vjp_autodiff, alg . jvp_autodiff, stats, internalnorm, kwargs...
189
188
)
190
189
globalization = Val (:TrustRegion )
191
190
end
@@ -194,7 +193,11 @@ function SciMLBase.__init(
194
193
NonlinearSolveBase. supports_line_search (alg. descent) ||
195
194
error (" Line Search not supported by $(alg. descent) ." )
196
195
linesearch_cache = CommonSolve. init (
197
- prob, alg. linesearch, fu, u; stats, internalnorm, kwargs...
196
+ prob, alg. linesearch, fu, u; stats, internalnorm,
197
+ autodiff = ifelse (
198
+ provided_jvp_autodiff, alg. jvp_autodiff, alg. vjp_autodiff
199
+ ),
200
+ kwargs...
198
201
)
199
202
globalization = Val (:LineSearch )
200
203
end
0 commit comments