Skip to content

Commit d0389b7

Browse files
committed
test: first order tests
1 parent 52a2387 commit d0389b7

26 files changed

+855
-695
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
name: CI (NonlinearSolveFirstOrder)
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- master
7+
paths:
8+
- "lib/NonlinearSolveFirstOrder/**"
9+
- ".github/workflows/CI_NonlinearSolveFirstOrder.yml"
10+
- "lib/NonlinearSolveBase/**"
11+
- "lib/SciMLJacobianOperators/**"
12+
push:
13+
branches:
14+
- master
15+
16+
concurrency:
17+
# Skip intermediate builds: always.
18+
# Cancel intermediate builds: only if it is a pull request build.
19+
group: ${{ github.workflow }}-${{ github.ref }}
20+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
21+
22+
jobs:
23+
test:
24+
runs-on: ${{ matrix.os }}
25+
strategy:
26+
fail-fast: false
27+
matrix:
28+
version:
29+
- "lts"
30+
- "1"
31+
os:
32+
- ubuntu-latest
33+
- macos-latest
34+
- windows-latest
35+
steps:
36+
- uses: actions/checkout@v4
37+
- uses: julia-actions/setup-julia@v2
38+
with:
39+
version: ${{ matrix.version }}
40+
- uses: actions/cache@v4
41+
env:
42+
cache-name: cache-artifacts
43+
with:
44+
path: ~/.julia/artifacts
45+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
46+
restore-keys: |
47+
${{ runner.os }}-test-${{ env.cache-name }}-
48+
${{ runner.os }}-test-
49+
${{ runner.os }}-
50+
- name: "Install Dependencies and Run Tests"
51+
run: |
52+
import Pkg
53+
Pkg.Registry.update()
54+
# Install packages present in subdirectories
55+
dev_pks = Pkg.PackageSpec[]
56+
for path in ("lib/SciMLJacobianOperators", "lib/NonlinearSolveBase")
57+
push!(dev_pks, Pkg.PackageSpec(; path))
58+
end
59+
Pkg.develop(dev_pks)
60+
Pkg.instantiate()
61+
Pkg.test(; coverage="user")
62+
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/NonlinearSolveFirstOrder {0}
63+
- uses: julia-actions/julia-processcoverage@v1
64+
with:
65+
directories: lib/NonlinearSolveFirstOrder/src,lib/NonlinearSolveBase/src,lib/NonlinearSolveBase/ext,lib/SciMLJacobianOperators/src
66+
- uses: codecov/codecov-action@v4
67+
with:
68+
file: lcov.info
69+
token: ${{ secrets.CODECOV_TOKEN }}
70+
verbose: true
71+
fail_ci_if_error: true

common/common_nlls_testing.jl

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using NonlinearSolveBase, SciMLBase, StableRNGs, ForwardDiff, Random, LinearAlgebra
2+
3+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
4+
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
5+
6+
θ_true = [1.0, 0.1, 2.0, 0.5]
7+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
8+
9+
const y_target = true_function(x, θ_true)
10+
11+
function loss_function(θ, p)
12+
= true_function(p, θ)
13+
return.- y_target
14+
end
15+
16+
function loss_function(resid, θ, p)
17+
true_function(resid, p, θ)
18+
resid .= resid .- y_target
19+
return resid
20+
end
21+
22+
θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
23+
24+
function vjp(v, θ, p)
25+
resid = zeros(length(p))
26+
J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ)
27+
return vec(v' * J)
28+
end
29+
30+
function vjp!(Jv, v, θ, p)
31+
resid = zeros(length(p))
32+
J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ)
33+
mul!(vec(Jv), transpose(J), v)
34+
return nothing
35+
end
36+
37+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
38+
prob_iip = NonlinearLeastSquaresProblem{true}(
39+
NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x
40+
)
41+
prob_oop_vjp = NonlinearLeastSquaresProblem(
42+
NonlinearFunction{false}(loss_function; vjp), θ_init, x
43+
)
44+
prob_iip_vjp = NonlinearLeastSquaresProblem(
45+
NonlinearFunction{true}(loss_function; resid_prototype = zero(y_target), vjp = vjp!),
46+
θ_init, x
47+
)
48+
49+
export prob_oop, prob_iip, prob_oop_vjp, prob_iip_vjp
File renamed without changes.

lib/NonlinearSolveBase/Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,14 @@ julia = "1.10"
8383

8484
[extras]
8585
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
86+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
8687
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8788
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
8889
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8990
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
91+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9092
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
9193
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9294

9395
[targets]
94-
test = ["Aqua", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "SparseArrays", "Test"]
96+
test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"]

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ include("solve.jl")
5757
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
5858
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
5959
@compat(public,
60-
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
61-
select_jacobian_autodiff))
60+
(select_forward_mode_autodiff, select_reverse_mode_autodiff, select_jacobian_autodiff))
6261

6362
# public for NonlinearSolve.jl and subpackages to use
6463
@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!))

lib/NonlinearSolveBase/src/linear_solve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function (cache::NativeJLLinearSolveCache)(;
102102
b === nothing || (cache.b = b)
103103

104104
if linu !== nothing && ArrayInterface.can_setindex(linu) &&
105-
applicable(ldiv!, linu, cache.A, cache.b)
105+
applicable(ldiv!, linu, cache.A, cache.b) && applicable(ldiv!, cache.A, linu)
106106
ldiv!(linu, cache.A, cache.b)
107107
res = linu
108108
else

lib/NonlinearSolveBase/src/utils.jl

+6
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ convert_real(::Type{T}, ::Nothing) where {T} = nothing
9191
convert_real(::Type{T}, x) where {T} = real(T(x))
9292

9393
restructure(::Number, x::Number) = x
94+
function restructure(
95+
y::T1, x::T2
96+
) where {T1 <: AbstractSciMLOperator, T2 <: AbstractSciMLOperator}
97+
@assert size(y) == size(x) "cannot restructure operators. ensure their sizes match."
98+
return x
99+
end
94100
restructure(y, x) = ArrayInterface.restructure(y, x)
95101

96102
function safe_similar(x, args...; kwargs...)

lib/NonlinearSolveBase/test/runtests.jl

+9
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,13 @@ using InteractiveUtils, Test
2424
@test check_no_stale_explicit_imports(NonlinearSolveBase) === nothing
2525
@test check_all_qualified_accesses_via_owners(NonlinearSolveBase) === nothing
2626
end
27+
28+
@testset "Banded Matrix vcat" begin
29+
using BandedMatrices, LinearAlgebra, SparseArrays
30+
31+
b = BandedMatrix(Ones(5, 5), (1, 1))
32+
d = Diagonal(ones(5, 5))
33+
34+
@test NonlinearSolveBase.Utils.faster_vcat(b, d) == vcat(sparse(b), d)
35+
end
2736
end

lib/NonlinearSolveFirstOrder/Project.toml

+22-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1212
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
14-
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1716
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
@@ -27,41 +26,61 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2726
ADTypes = "1.9.0"
2827
Aqua = "0.8"
2928
ArrayInterface = "7.16.0"
29+
BandedMatrices = "1.7.5"
30+
BenchmarkTools = "1.5.0"
3031
CommonSolve = "0.2.4"
3132
ConcreteStructs = "0.2.3"
3233
DiffEqBase = "6.155.3"
34+
Enzyme = "0.13.12"
3335
ExplicitImports = "1.5"
3436
FiniteDiff = "2.26.0"
3537
ForwardDiff = "0.10.36"
3638
Hwloc = "3"
3739
InteractiveUtils = "<0.0.1, 1"
3840
LineSearch = "0.1.4"
39-
LinearAlgebra = "1.11.0"
41+
LineSearches = "7.3.0"
42+
LinearAlgebra = "1.10"
4043
LinearSolve = "2.36.1"
4144
MaybeInplace = "0.1.4"
4245
NonlinearProblemLibrary = "0.1.2"
4346
NonlinearSolveBase = "1.1"
4447
Pkg = "1.10"
4548
PrecompileTools = "1.2"
49+
Random = "1.10"
4650
ReTestItems = "1.24"
4751
Reexport = "1"
4852
SciMLBase = "2.54"
53+
SciMLJacobianOperators = "0.1.0"
4954
Setfield = "1.1.1"
55+
SparseConnectivityTracer = "0.6.8"
56+
SparseMatrixColorings = "0.4.8"
5057
StableRNGs = "1"
58+
StaticArrays = "1.9.8"
5159
StaticArraysCore = "1.4.3"
5260
Test = "1.10"
61+
Zygote = "0.6.72"
5362
julia = "1.10"
5463

5564
[extras]
5665
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
66+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
67+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
68+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5769
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
5870
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
5971
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
72+
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
73+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
6074
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
6175
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
76+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6277
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
78+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
79+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
6380
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
81+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
6482
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
83+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6584

6685
[targets]
67-
test = ["Aqua", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "Test"]
86+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ function LevenbergMarquardt(;
5757
autodiff,
5858
vjp_autodiff,
5959
jvp_autodiff,
60-
name = :LevenbergMarquardt
60+
name = :LevenbergMarquardt,
61+
concrete_jac = Val(true)
6162
)
6263
end
6364

lib/NonlinearSolveFirstOrder/src/pseudo_transient.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ function InternalAPI.init(
7272
) where {F}
7373
T = promote_type(eltype(u), eltype(fu))
7474
return SwitchedEvolutionRelaxationCache(
75-
internalnorm(fu),
76-
T(inv(initial_damping)),
77-
internalnorm
75+
internalnorm(fu), T(inv(initial_damping)), internalnorm
7876
)
7977
end
8078

lib/NonlinearSolveFirstOrder/src/solve.jl

+35-32
Original file line numberDiff line numberDiff line change
@@ -89,53 +89,52 @@ end
8989
kwargs
9090
end
9191

92-
# XXX: Implement
93-
# function __reinit_internal!(
94-
# cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u,
95-
# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}
96-
# if iip
97-
# recursivecopy!(cache.u, u0)
98-
# cache.prob.f(cache.fu, cache.u, p)
99-
# else
100-
# cache.u = __maybe_unaliased(u0, alias_u0)
101-
# set_fu!(cache, cache.prob.f(cache.u, p))
102-
# end
103-
# cache.p = p
104-
105-
# __reinit_internal!(cache.stats)
106-
# cache.nsteps = 0
107-
# cache.maxiters = maxiters
108-
# cache.maxtime = maxtime
109-
# cache.total_time = 0.0
110-
# cache.force_stop = false
111-
# cache.retcode = ReturnCode.Default
112-
# cache.make_new_jacobian = true
113-
114-
# reset!(cache.trace)
115-
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
116-
# reset_timer!(cache.timer)
117-
# end
92+
function InternalAPI.reinit_self!(
93+
cache::GeneralizedFirstOrderAlgorithmCache, args...; p = cache.p, u0 = cache.u,
94+
alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...
95+
)
96+
Utils.reinit_common!(cache, u0, p, alias_u0)
97+
98+
InternalAPI.reinit!(cache.stats)
99+
cache.nsteps = 0
100+
cache.maxiters = maxiters
101+
cache.maxtime = maxtime
102+
cache.total_time = 0.0
103+
cache.force_stop = false
104+
cache.retcode = ReturnCode.Default
105+
cache.make_new_jacobian = true
106+
107+
NonlinearSolveBase.reset!(cache.trace)
108+
SciMLBase.reinit!(
109+
cache.termination_cache, NonlinearSolveBase.get_fu(cache),
110+
NonlinearSolveBase.get_u(cache); kwargs...
111+
)
112+
NonlinearSolveBase.reset_timer!(cache.timer)
113+
return
114+
end
118115

119116
NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache,
120117
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache)
121118

122119
function SciMLBase.__init(
123-
prob::AbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm,
124-
args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
120+
prob::AbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...;
121+
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
125122
abstol = nothing, reltol = nothing, maxtime = nothing,
126123
termination_condition = nothing, internalnorm = L2_NORM,
127124
linsolve_kwargs = (;), kwargs...
128125
)
129126
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
130-
@set! alg.jvp_autodiff = if alg.jvp_autodiff === nothing && alg.autodiff !== nothing &&
127+
provided_jvp_autodiff = alg.jvp_autodiff !== nothing
128+
@set! alg.jvp_autodiff = if !provided_jvp_autodiff && alg.autodiff !== nothing &&
131129
(ADTypes.mode(alg.autodiff) isa ADTypes.ForwardMode ||
132130
ADTypes.mode(alg.autodiff) isa
133131
ADTypes.ForwardOrReverseMode)
134132
NonlinearSolveBase.select_forward_mode_autodiff(prob, alg.autodiff)
135133
else
136134
NonlinearSolveBase.select_forward_mode_autodiff(prob, alg.jvp_autodiff)
137135
end
138-
@set! alg.vjp_autodiff = if alg.vjp_autodiff === nothing && alg.autodiff !== nothing &&
136+
provided_vjp_autodiff = alg.vjp_autodiff !== nothing
137+
@set! alg.vjp_autodiff = if !provided_vjp_autodiff && alg.autodiff !== nothing &&
139138
(ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode ||
140139
ADTypes.mode(alg.autodiff) isa
141140
ADTypes.ForwardOrReverseMode)
@@ -185,7 +184,7 @@ function SciMLBase.__init(
185184
error("Trust Region not supported by $(alg.descent).")
186185
trustregion_cache = InternalAPI.init(
187186
prob, alg.trustregion, prob.f, fu, u, prob.p;
188-
stats, internalnorm, kwargs...
187+
alg.vjp_autodiff, alg.jvp_autodiff, stats, internalnorm, kwargs...
189188
)
190189
globalization = Val(:TrustRegion)
191190
end
@@ -194,7 +193,11 @@ function SciMLBase.__init(
194193
NonlinearSolveBase.supports_line_search(alg.descent) ||
195194
error("Line Search not supported by $(alg.descent).")
196195
linesearch_cache = CommonSolve.init(
197-
prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...
196+
prob, alg.linesearch, fu, u; stats, internalnorm,
197+
autodiff = ifelse(
198+
provided_jvp_autodiff, alg.jvp_autodiff, alg.vjp_autodiff
199+
),
200+
kwargs...
198201
)
199202
globalization = Val(:LineSearch)
200203
end

0 commit comments

Comments
 (0)