Skip to content

Commit 80e82c7

Browse files
committed
sparsity improvements
1 parent 1f8a4ee commit 80e82c7

File tree

8 files changed

+192
-121
lines changed

8 files changed

+192
-121
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonconvexUtils"
22
uuid = "c48e48a2-1f5e-44ff-8799-c8e168d11d1b"
33
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
4-
version = "0.3.0"
4+
version = "0.4.0"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
@@ -24,7 +24,7 @@ ForwardDiff = "0.10"
2424
IterativeSolvers = "0.8, 0.9"
2525
LinearMaps = "3"
2626
MacroTools = "0.5"
27-
NonconvexCore = "1.0.8"
27+
NonconvexCore = "1.1"
2828
SparseDiffTools = "1.24"
2929
Symbolics = "4.6"
3030
Zygote = "0.5, 0.6"

src/NonconvexUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export forwarddiffy,
1313

1414
using ChainRulesCore, AbstractDifferentiation, ForwardDiff, LinearAlgebra
1515
using Zygote, LinearMaps, IterativeSolvers, NonconvexCore, SparseArrays
16-
using NonconvexCore: flatten
16+
using NonconvexCore: flatten, tovecfunc, _sparsevec, _sparse_reshape
1717
using MacroTools
1818
using Symbolics: Symbolics
1919
using SparseDiffTools: SparseDiffTools

src/abstractdiff.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,10 @@ function ChainRulesCore.frule(
1818
end
1919
@ForwardDiff_frule (f::AbstractDiffFunction)(x::AbstractVector{<:ForwardDiff.Dual})
2020

21-
function tovecfunc(f, x...)
22-
vx, _unflattenx = flatten(x)
23-
unflattenx = NonconvexCore.Unflatten(x, _unflattenx)
24-
y = f(x...)
25-
tmp = NonconvexCore.maybeflatten(y)
26-
# should be addressed in maybeflatten
27-
if y isa Real
28-
unflatteny = identity
29-
else
30-
unflatteny = NonconvexCore.Unflatten(y, tmp[2])
31-
end
32-
return x -> NonconvexCore.maybeflatten(f(unflattenx(x)...))[1], float.(vx), unflatteny
33-
end
34-
3521
# does not assume vector input and output
3622
forwarddiffy(f_or_m, x...) = abstractdiffy(f_or_m, AD.ForwardDiffBackend(), x...)
3723
function abstractdiffy(f, backend, x...)
38-
flat_f, vx, unflatteny = tovecfunc(f, x...)
24+
flat_f, _, unflatteny = tovecfunc(f, x...)
3925
ad_flat_f = AbstractDiffFunction(flat_f, backend)
4026
return (x...,) -> unflatteny(ad_flat_f(flatten(x)[1]))
4127
end

src/custom.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,28 @@ struct CustomGradFunction{F, G} <: Function
44
end
55
(f::CustomGradFunction)(x::AbstractVector) = f.f(x)
66
function ChainRulesCore.rrule(f::CustomGradFunction, x::AbstractVector)
7-
return f.f(x), Δ -> begin
8-
G = f.g(x)
7+
v = f.f(x)
8+
return v, Δ -> begin
9+
if f.g === nothing
10+
if v isa Real
11+
G = spzeros(eltype(v), length(x))
12+
else
13+
G = spzeros(eltype(v), length(v), length(x))
14+
end
15+
else
16+
G = f.g(x)
17+
end
918
if G isa AbstractVector
1019
return (NoTangent(), G * Δ)
11-
else
20+
elseif G isa LazyJacobian
1221
return (NoTangent(), G' * Δ)
22+
else
23+
spΔ = dropzeros!(sparse(copy(Δ)))
24+
if length(spΔ.nzval) == 1
25+
return (NoTangent(), G[spΔ.nzind[1], :] * spΔ.nzval[1])
26+
else
27+
return (NoTangent(), G' * Δ)
28+
end
1329
end
1430
end
1531
end
@@ -19,24 +35,25 @@ function ChainRulesCore.frule(
1935
v = f.f(x)
2036
if f.g === nothing
2137
if v isa Real
22-
∇ = zeros(eltype(v), length(x))'
38+
∇ = spzeros(eltype(v), 1, length(x))
2339
else
24-
∇ = zeros(eltype(v), length(v), length(x))
40+
∇ = spzeros(eltype(v), length(v), length(x))
2541
end
2642
else
2743
∇ = f.g(x)
2844
end
45+
project_to = ProjectTo(v)
2946
if ∇ isa AbstractVector && Δx isa AbstractVector
3047
if !(∇ isa LazyJacobian) && issparse(∇) && nnz(∇) == 0
31-
return v, zero(eltype(Δx))
48+
return v, project_to(zero(eltype(Δx)))
3249
else
33-
return v, ∇' * Δx
50+
return v, project_to(∇' * Δx)
3451
end
3552
else
3653
if !(∇ isa LazyJacobian) && issparse(∇) && nnz(∇) == 0
37-
return v, zeros(eltype(Δx), size(∇, 1))
54+
return v, project_to(spzeros(eltype(Δx), size(∇, 1)))
3855
else
39-
return v, ∇ * Δx
56+
return v, project_to(_sparse_reshape(∇ * Δx, size(v)...))
4057
end
4158
end
4259
end
@@ -64,10 +81,11 @@ function ChainRulesCore.frule(
6481
)
6582
g = CustomGradFunction(f.g, f.h)
6683
v, ∇ = f(x), g(x)
84+
project_to = ProjectTo(v)
6785
if ∇ isa AbstractVector && Δx isa AbstractVector
68-
return v, ∇' * Δx
86+
return v, project_to(∇' * Δx)
6987
else
70-
return v, ∇ * Δx
88+
return v, project_to(∇ * Δx)
7189
end
7290
end
7391
@ForwardDiff_frule (f::CustomHessianFunction)(x::AbstractVector{<:ForwardDiff.Dual})

src/sparse_forwarddiff.jl

Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
4242
xT = eltype(x)
4343
if length(jac.nzval) > 0
4444
_jac = SparseDiffTools.forwarddiff_color_jacobian(_f, x, colorvec = jac_colors, sparsity = jac, jac_prototype = xT.(jac))
45-
return copy(_jac)
45+
project_to = ChainRulesCore.ProjectTo(jac)
46+
return project_to(copy(_jac))
4647
else
4748
return sparse(Int[], Int[], xT[], size(jac)...)
4849
end
@@ -55,7 +56,8 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
5556
_J = x -> _sparsevec(J(x))
5657
H = x -> begin
5758
_hess = SparseDiffTools.forwarddiff_color_jacobian(_J, x, colorvec = hess_colors, sparsity = hess_pattern, jac_prototype = hess)
58-
return copy(_hess)
59+
project_to = ChainRulesCore.ProjectTo(hess)
60+
return project_to(copy(_hess))
5961
end
6062
else
6163
T = eltype(G)
@@ -71,46 +73,6 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
7173
return SparseForwardDiffFunction(f, f!, y, jac, jac_pattern, jac_colors, J, vecJ!, G, hess, hess_pattern, hess_colors, H)
7274
end
7375

74-
_sparsevec(x::Real) = [x]
75-
_sparsevec(x::Vector) = copy(x)
76-
_sparsevec(x::Matrix) = copy(vec(x))
77-
function _sparsevec(x::SparseMatrixCSC)
78-
m, n = size(x)
79-
linear_inds = zeros(Int, length(x.nzval))
80-
count = 1
81-
for colind in 1:length(x.colptr)-1
82-
for ind in x.colptr[colind]:x.colptr[colind+1]-1
83-
rowind = x.rowval[ind]
84-
val = x.nzval[ind]
85-
linear_inds[count] = rowind + (colind - 1) * m
86-
count += 1
87-
end
88-
end
89-
return sparsevec(linear_inds, copy(x.nzval), prod(size(x)))
90-
end
91-
92-
# can be made more efficient using div and mod
93-
function _sparse_reshape(v::SparseVector, m, n)
94-
if length(v.nzval) == 0
95-
return sparse(Int[], Int[], v.nzval, m, n)
96-
end
97-
ind = 1
98-
N = length(v.nzval)
99-
I = zeros(Int, N)
100-
J = zeros(Int, N)
101-
for i in 1:m, j in 1:n
102-
if (j - 1) * m + i == v.nzind[ind]
103-
I[ind] = i
104-
J[ind] = j
105-
ind += 1
106-
end
107-
if ind > N
108-
break
109-
end
110-
end
111-
return sparse(I, J, copy(v.nzval), m, n)
112-
end
113-
11476
(f::SparseForwardDiffFunction)(x) = f.f(x)
11577
function ChainRulesCore.rrule(f::SparseForwardDiffFunction, x::AbstractVector)
11678
if f.H === nothing
@@ -122,9 +84,14 @@ function ChainRulesCore.rrule(f::SparseForwardDiffFunction, x::AbstractVector)
12284
jac = J(x)
12385
return val, Δ -> begin
12486
if val isa Real
125-
(NoTangent(), sparse(vec(jac' * Δ)))
87+
(NoTangent(), jac' * Δ)
12688
else
127-
(NoTangent(), jac' * sparse(vec(Δ)))
89+
spΔ = dropzeros!(sparse(_sparsevec(copy(Δ))))
90+
if length(spΔ.nzval) == 1
91+
(NoTangent(), jac[spΔ.nzind[1], :] * spΔ.nzval[1])
92+
else
93+
(NoTangent(), jac' * spΔ)
94+
end
12895
end
12996
end
13097
end
@@ -139,46 +106,82 @@ function ChainRulesCore.frule((_, Δx), f::SparseForwardDiffFunction, x::Abstrac
139106
if val isa Real
140107
Δy = only(jac * Δx)
141108
elseif val isa AbstractVector
142-
Δy = jac * sparse(vec(Δx))
109+
spΔx = dropzeros!(sparse(_sparsevec(copy(Δx))))
110+
if length(spΔx.nzval) == 1
111+
Δy = jac[:, spΔx.nzind[1]] * spΔx.nzval[1]
112+
else
113+
Δy = jac * spΔx
114+
end
143115
else
144-
Δy = _sparse_reshape(jac * sparse(vec(Δx)), size(val)...)
116+
spΔx = dropzeros!(sparse(_sparsevec(copy(Δx))))
117+
Δy = _sparse_reshape(jac * spΔx, size(val)...)
145118
end
146119
project_to = ChainRulesCore.ProjectTo(val)
147120
return val, project_to(Δy)
148121
end
149122
@ForwardDiff_frule (f::SparseForwardDiffFunction)(x::AbstractVector{<:ForwardDiff.Dual})
150123

151-
function sparsify(f, x...; kwargs...)
152-
# defined in the abstractdiff.jl file
153-
flat_f, vx, unflatteny = tovecfunc(f, x...)
154-
sp_flat_f = SparseForwardDiffFunction(flat_f, vx; kwargs...)
155-
return x -> unflatteny(sp_flat_f(flatten(x)[1]))
124+
struct UnflattennedFunction{F1, F2, V, U} <: Function
125+
f::F1
126+
flat_f::F2
127+
v::V
128+
unflatten::U
129+
flatteny::Bool
130+
end
131+
(f::UnflattennedFunction)(x...) = f.f(x...)
132+
function NonconvexCore.tovecfunc(f::UnflattennedFunction, x...; flatteny = true)
133+
@assert flatteny == f.flatteny
134+
return f.flat_f, f.v, f.unflatten
135+
end
136+
137+
function sparsify(f, x...; flatteny = true, kwargs...)
138+
flat_f, vx, unflatteny = tovecfunc(f, x...; flatteny)
139+
if length(x) == 1 && x[1] isa AbstractVector
140+
flat_f = f
141+
sp_flat_f = SparseForwardDiffFunction(flat_f, vx; kwargs...)
142+
return UnflattennedFunction(
143+
x -> unflatteny(sp_flat_f(x)),
144+
sp_flat_f,
145+
vx,
146+
unflatteny,
147+
flatteny,
148+
)
149+
else
150+
sp_flat_f = SparseForwardDiffFunction(flat_f, vx; kwargs...)
151+
return UnflattennedFunction(
152+
x -> unflatteny(sp_flat_f(flatten(x)[1])),
153+
sp_flat_f,
154+
vx,
155+
unflatteny,
156+
flatteny,
157+
)
158+
end
156159
end
157160

158161
function sparsify(model::NonconvexCore.AbstractModel; objective = true, ineq_constraints = true, eq_constraints = true, sd_constraints = true, kwargs...)
159162
x = getmin(model)
160163
if objective
161-
obj = NonconvexCore.Objective(sparsify(model.objective, x; kwargs...), flags = model.objective.flags)
164+
obj = NonconvexCore.Objective(sparsify(model.objective.f, x; kwargs...), model.objective.multiple, model.objective.flags)
162165
else
163166
obj = model.objective
164167
end
165168
if ineq_constraints
166169
ineq = length(model.ineq_constraints.fs) != 0 ? NonconvexCore.VectorOfFunctions(map(model.ineq_constraints.fs) do c
167-
return NonconvexCore.IneqConstraint(sparsify(c, x; kwargs...), c.rhs, c.dim, c.flags)
170+
return NonconvexCore.IneqConstraint(sparsify(c.f, x; kwargs...), c.rhs, c.dim, c.flags)
168171
end) : NonconvexCore.VectorOfFunctions(NonconvexCore.IneqConstraint[])
169172
else
170173
ineq = model.ineq_constraints
171174
end
172175
if eq_constraints
173176
eq = length(model.eq_constraints.fs) != 0 ? NonconvexCore.VectorOfFunctions(map(model.eq_constraints.fs) do c
174-
return NonconvexCore.EqConstraint(sparsify(c, x; kwargs...), c.rhs, c.dim, c.flags)
177+
return NonconvexCore.EqConstraint(sparsify(c.f, x; kwargs...), c.rhs, c.dim, c.flags)
175178
end) : NonconvexCore.VectorOfFunctions(NonconvexCore.EqConstraint[])
176179
else
177180
eq = model.eq_constraints
178181
end
179182
if sd_constraints
180183
sd = length(model.sd_constraints.fs) != 0 ? NonconvexCore.VectorOfFunctions(map(model.sd_constraints.fs) do c
181-
return NonconvexCore.SDConstraint(sparsify(c, x; kwargs...), c.dim)
184+
return NonconvexCore.SDConstraint(sparsify(c.f, x; flatteny = false, kwargs...), c.dim)
182185
end) : NonconvexCore.VectorOfFunctions(NonconvexCore.SDConstraint[])
183186
else
184187
sd = model.sd_constraints

0 commit comments

Comments
 (0)