Skip to content

Commit 87c3ddc

Browse files
Fix return codes for nonlinear least squares
Fixes #459. The crux of the issue is that `f(x) = residual` only applies in the NonlinearProblem and SteadyStateProblem cases. When `f(x)` is a nonlinear least squares problem, finding a local minima is a solution, not a failure of the algorithm. Thus this reclassifies Stalled in NLLSQ to StalledSuccess, which makes it a successful return. Algorithms which require the NonlinearLeastSquares solution to have `||resid|| < tol` thus need to be careful with the return handling, as is done in the PR that introduces this return code SciML/SciMLBase.jl#1016. However, that's a fairly odd case because it's feasibility checking, while the normal use case for NLLSQ is for optimization, and in an optimization case there's no reason to believe you should always have a solution close to zero.
1 parent 73f7d1f commit 87c3ddc

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolveBase"
22
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.7.0"
4+
version = "1.8.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -68,7 +68,7 @@ MaybeInplace = "0.1.4"
6868
Preferences = "1.4"
6969
Printf = "1.10"
7070
RecursiveArrayTools = "3"
71-
SciMLBase = "2.69"
71+
SciMLBase = "2.89"
7272
SciMLJacobianOperators = "0.1.1"
7373
SciMLOperators = "0.3.13, 0.4"
7474
SparseArrays = "1.10"

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ end
160160
InternalAPI.step!($(cache_syms[i]), args...; kwargs...)
161161
$(cache_syms[i]).nsteps += 1
162162
if !NonlinearSolveBase.not_terminated($(cache_syms[i]))
163-
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
163+
# If a NonlinearLeastSquaresProblem StalledSuccess, try the next
164+
# solver to see if you get a lower residual
165+
if SciMLBase.successful_retcode($(cache_syms[i]).retcode) &&
166+
$(cache_syms[i]).retcode != ReturnCode.StalledSuccess
164167
cache.best = $(i)
165168
cache.force_stop = true
166169
cache.retcode = $(cache_syms[i]).retcode

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const AbsNormModes = Union{
2121
step_norm_trace
2222
max_stalled_steps
2323
u_diff_cache::uType
24+
leastsq::Bool
2425
end
2526

2627
get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol
@@ -36,7 +37,7 @@ function update_u!!(cache::NonlinearTerminationModeCache, u)
3637
end
3738

3839
function CommonSolve.init(
39-
::AbstractNonlinearProblem, mode::AbstractNonlinearTerminationMode, du, u,
40+
prob::AbstractNonlinearProblem, mode::AbstractNonlinearTerminationMode, du, u,
4041
saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...
4142
)
4243
T = promote_type(eltype(du), eltype(u))
@@ -80,10 +81,12 @@ function CommonSolve.init(
8081

8182
length(saved_value_prototype) == 0 && (saved_value_prototype = nothing)
8283

84+
leastsq = typeof(prob) <: NonlinearLeastSquaresProblem
85+
8386
return NonlinearTerminationModeCache(
8487
u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode,
8588
initial_objective, objectives_trace, 0, saved_value_prototype,
86-
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache
89+
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache, leastsq
8790
)
8891
end
8992

@@ -146,6 +149,7 @@ end
146149
function (cache::NonlinearTerminationModeCache)(
147150
mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
148151
)
152+
149153
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
150154
objective = Utils.apply_norm(mode.internalnorm, du)
151155
criteria = abstol
@@ -177,7 +181,7 @@ function (cache::NonlinearTerminationModeCache)(
177181
end
178182

179183
# Main Termination Criteria
180-
if objective criteria
184+
if !cache.leastsq && objective criteria
181185
cache.retcode = ReturnCode.Success
182186
return true
183187
end
@@ -195,7 +199,13 @@ function (cache::NonlinearTerminationModeCache)(
195199
min_obj, max_obj = extrema(cache.objectives_trace)
196200
end
197201
if min_obj < mode.min_max_factor * max_obj
198-
cache.retcode = ReturnCode.Stalled
202+
if cache.leastsq
203+
# If least squares, found a local minima thus success
204+
cache.retcode = ReturnCode.StalledSuccess
205+
else
206+
# Not a success if f(x)>0 and residual too high
207+
cache.retcode = ReturnCode.Stalled
208+
end
199209
return true
200210
end
201211
end
@@ -209,7 +219,7 @@ function (cache::NonlinearTerminationModeCache)(
209219
end
210220
du_norm = L2_NORM(cache.u_diff_cache)
211221
cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm
212-
if cache.nsteps > mode.max_stalled_steps
222+
if cache.nsteps > mode.max_stalled_steps || iszero(du_norm)
213223
max_step_norm = maximum(cache.step_norm_trace)
214224
if mode isa AbsNormSafeTerminationMode ||
215225
mode isa AbsNormSafeBestTerminationMode
@@ -218,7 +228,11 @@ function (cache::NonlinearTerminationModeCache)(
218228
stalled_step = max_step_norm reltol * (max_step_norm + cache.u0_norm)
219229
end
220230
if stalled_step
221-
cache.retcode = ReturnCode.Stalled
231+
if cache.leastsq
232+
cache.retcode = ReturnCode.StalledSuccess
233+
else
234+
cache.retcode = ReturnCode.Stalled
235+
end
222236
return true
223237
end
224238
end

test/core_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,13 @@ end
427427

428428
@test !(solve(nlprob, NewtonRaphson()).alg.autodiff isa AutoPolyesterForwardDiff)
429429
end
430+
431+
@testitem "NonlinearLeastSquares ReturnCode" tags=[:core] begin
432+
f(u,p) = [1.0]
433+
nlf = NonlinearFunction(f; resid_prototype=zeros(1))
434+
prob = NonlinearLeastSquaresProblem(nlf, [1.0])
435+
sol = solve(prob)
436+
@test SciMLBase.successful_retcode(sol)
437+
@test sol.retcode == ReturnCode.StalledSuccess
438+
@test sol.stats.nf == 3
439+
end

0 commit comments

Comments
 (0)