@@ -42,7 +42,8 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
42
42
xT = eltype (x)
43
43
if length (jac. nzval) > 0
44
44
_jac = SparseDiffTools. forwarddiff_color_jacobian (_f, x, colorvec = jac_colors, sparsity = jac, jac_prototype = xT .(jac))
45
- return copy (_jac)
45
+ project_to = ChainRulesCore. ProjectTo (jac)
46
+ return project_to (copy (_jac))
46
47
else
47
48
return sparse (Int[], Int[], xT[], size (jac)... )
48
49
end
@@ -55,7 +56,8 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
55
56
_J = x -> _sparsevec (J (x))
56
57
H = x -> begin
57
58
_hess = SparseDiffTools. forwarddiff_color_jacobian (_J, x, colorvec = hess_colors, sparsity = hess_pattern, jac_prototype = hess)
58
- return copy (_hess)
59
+ project_to = ChainRulesCore. ProjectTo (hess)
60
+ return project_to (copy (_hess))
59
61
end
60
62
else
61
63
T = eltype (G)
@@ -71,46 +73,6 @@ function SparseForwardDiffFunction(f, x::AbstractVector; hessian = false, jac_pa
71
73
return SparseForwardDiffFunction (f, f!, y, jac, jac_pattern, jac_colors, J, vecJ!, G, hess, hess_pattern, hess_colors, H)
72
74
end
73
75
74
# `_sparsevec` turns a value into a flat vector, dispatching on the input type:
#   * a real number becomes a one-element vector;
#   * a dense vector is copied;
#   * a dense matrix is copied into a vector in column-major order;
#   * a sparse matrix becomes a sparse vector of length `length(x)` whose stored
#     indices are the column-major linear indices of the matrix's stored entries.
_sparsevec(x::Real) = [x]
_sparsevec(x::Vector) = copy(x)
_sparsevec(x::Matrix) = copy(vec(x))
function _sparsevec(x::SparseMatrixCSC)
    nrows = size(x, 1)
    # `findnz` returns freshly allocated (row, column, value) triplets of the stored
    # entries in column-major order, so the linear indices below are already sorted.
    rows, cols, vals = findnz(x)
    linear_inds = rows .+ (cols .- 1) .* nrows
    return sparsevec(linear_inds, vals, length(x))
end
91
-
92
"""
    _sparse_reshape(v::SparseVector, m, n)

Build an `m × n` sparse matrix from `v`, interpreting every stored index `k` of `v`
as a column-major linear index into the matrix, i.e. `k == (j - 1) * m + i` for the
entry at row `i`, column `j`.

The row and column of each stored entry are computed directly from `k`, so the cost
is proportional to the number of stored entries rather than to `m * n`. The stored
values are copied; `v` is not modified.
"""
function _sparse_reshape(v::SparseVector, m, n)
    # `findnz` hands back copies of the stored indices and values.
    linear_inds, vals = findnz(v)
    # One-based remainder / quotient invert `k = (j - 1) * m + i`.
    rows = Int.(mod1.(linear_inds, m))
    cols = Int.(fld1.(linear_inds, m))
    return sparse(rows, cols, vals, m, n)
end
113
-
114
76
# Calling the wrapper forwards the argument to the function stored in its `f` field.
function (sf::SparseForwardDiffFunction)(x)
    return sf.f(x)
end
115
77
function ChainRulesCore. rrule (f:: SparseForwardDiffFunction , x:: AbstractVector )
116
78
if f. H === nothing
@@ -122,9 +84,14 @@ function ChainRulesCore.rrule(f::SparseForwardDiffFunction, x::AbstractVector)
122
84
jac = J (x)
123
85
return val, Δ -> begin
124
86
if val isa Real
125
- (NoTangent (), sparse ( vec ( jac' * Δ)) )
87
+ (NoTangent (), jac' * Δ)
126
88
else
127
- (NoTangent (), jac' * sparse (vec (Δ)))
89
+ spΔ = dropzeros! (sparse (_sparsevec (copy (Δ))))
90
+ if length (spΔ. nzval) == 1
91
+ (NoTangent (), jac[spΔ. nzind[1 ], :] * spΔ. nzval[1 ])
92
+ else
93
+ (NoTangent (), jac' * spΔ)
94
+ end
128
95
end
129
96
end
130
97
end
@@ -139,46 +106,82 @@ function ChainRulesCore.frule((_, Δx), f::SparseForwardDiffFunction, x::Abstrac
139
106
if val isa Real
140
107
Δy = only (jac * Δx)
141
108
elseif val isa AbstractVector
142
- Δy = jac * sparse (vec (Δx))
109
+ spΔx = dropzeros! (sparse (_sparsevec (copy (Δx))))
110
+ if length (spΔx. nzval) == 1
111
+ Δy = jac[:, spΔx. nzind[1 ]] * spΔx. nzval[1 ]
112
+ else
113
+ Δy = jac * spΔx
114
+ end
143
115
else
144
- Δy = _sparse_reshape (jac * sparse (vec (Δx)), size (val)... )
116
+ spΔx = dropzeros! (sparse (_sparsevec (copy (Δx))))
117
+ Δy = _sparse_reshape (jac * spΔx, size (val)... )
145
118
end
146
119
project_to = ChainRulesCore. ProjectTo (val)
147
120
return val, project_to (Δy)
148
121
end
149
122
# NOTE(review): presumably makes calls of a SparseForwardDiffFunction on a vector of
# ForwardDiff.Dual numbers go through the ChainRulesCore.frule defined for this type —
# confirm against the package that provides `@ForwardDiff_frule`.
@ForwardDiff_frule (f:: SparseForwardDiffFunction )(x:: AbstractVector{<:ForwardDiff.Dual} )
150
123
151
- function sparsify (f, x... ; kwargs... )
152
- # defined in the abstractdiff.jl file
153
- flat_f, vx, unflatteny = tovecfunc (f, x... )
154
- sp_flat_f = SparseForwardDiffFunction (flat_f, vx; kwargs... )
155
- return x -> unflatteny (sp_flat_f (flatten (x)[1 ]))
124
"""
    UnflattennedFunction(f, flat_f, v, unflatten, flatteny)

Callable bundle that keeps a function together with its flattened counterpart.

# Fields
- `f`: the function invoked when the object itself is called.
- `flat_f`: the flattened version of the function.
- `v`: the vector stored alongside `flat_f`.
- `unflatten`: the function stored for undoing the flattening.
- `flatteny`: the `flatteny` flag the bundle was created with.
"""
struct UnflattennedFunction{F1,F2,V,U} <: Function
    f::F1
    flat_f::F2
    v::V
    unflatten::U
    flatteny::Bool
end

# Calling the bundle forwards every positional argument to the stored `f`.
function (uf::UnflattennedFunction)(args...)
    return uf.f(args...)
end
132
"""
    NonconvexCore.tovecfunc(f::UnflattennedFunction, x...; flatteny = true)

Return the `(flat_f, v, unflatten)` triple already stored in `f` instead of
recomputing it. The positional arguments `x...` are accepted but not used.

# Throws
- `ArgumentError` if `flatteny` differs from the `flatteny` flag stored in `f`.
"""
function NonconvexCore.tovecfunc(f::UnflattennedFunction, x...; flatteny = true)
    # Explicit check instead of `@assert`: assertions may be disabled, and this
    # validates a caller-supplied keyword rather than an internal invariant.
    if flatteny != f.flatteny
        throw(ArgumentError(
            "flatteny = $(flatteny) does not match the stored flatteny = $(f.flatteny)",
        ))
    end
    return f.flat_f, f.v, f.unflatten
end
136
+
137
"""
    sparsify(f, x...; flatteny = true, kwargs...)

Wrap `f` in a `SparseForwardDiffFunction` and return an `UnflattennedFunction`
bundling the resulting callable with the pieces returned by `tovecfunc`.

`tovecfunc(f, x...; flatteny)` supplies the flattened function, the flat input vector
and the output unflattener. When the only positional input is an `AbstractVector`,
`f` itself is wrapped and the returned callable passes its argument straight through;
otherwise the flattened function is wrapped and the argument is flattened with
`flatten` first. Remaining keyword arguments are forwarded to
`SparseForwardDiffFunction`.
"""
function sparsify(f, x...; flatteny = true, kwargs...)
    flat_f, vx, unflatteny = tovecfunc(f, x...; flatteny)
    takes_plain_vector = length(x) == 1 && x[1] isa AbstractVector

    # A lone vector input needs no flattening, so `f` is differentiated directly.
    inner = takes_plain_vector ? f : flat_f
    sp_flat_f = SparseForwardDiffFunction(inner, vx; kwargs...)

    wrapped = if takes_plain_vector
        v -> unflatteny(sp_flat_f(v))
    else
        z -> unflatteny(sp_flat_f(flatten(z)[1]))
    end
    return UnflattennedFunction(wrapped, sp_flat_f, vx, unflatteny, flatteny)
end
157
160
158
161
function sparsify (model:: NonconvexCore.AbstractModel ; objective = true , ineq_constraints = true , eq_constraints = true , sd_constraints = true , kwargs... )
159
162
x = getmin (model)
160
163
if objective
161
- obj = NonconvexCore. Objective (sparsify (model. objective, x; kwargs... ), flags = model. objective. flags)
164
+ obj = NonconvexCore. Objective (sparsify (model. objective. f , x; kwargs... ), model . objective . multiple, model. objective. flags)
162
165
else
163
166
obj = model. objective
164
167
end
165
168
if ineq_constraints
166
169
ineq = length (model. ineq_constraints. fs) != 0 ? NonconvexCore. VectorOfFunctions (map (model. ineq_constraints. fs) do c
167
- return NonconvexCore. IneqConstraint (sparsify (c, x; kwargs... ), c. rhs, c. dim, c. flags)
170
+ return NonconvexCore. IneqConstraint (sparsify (c. f , x; kwargs... ), c. rhs, c. dim, c. flags)
168
171
end ) : NonconvexCore. VectorOfFunctions (NonconvexCore. IneqConstraint[])
169
172
else
170
173
ineq = model. ineq_constraints
171
174
end
172
175
if eq_constraints
173
176
eq = length (model. eq_constraints. fs) != 0 ? NonconvexCore. VectorOfFunctions (map (model. eq_constraints. fs) do c
174
- return NonconvexCore. EqConstraint (sparsify (c, x; kwargs... ), c. rhs, c. dim, c. flags)
177
+ return NonconvexCore. EqConstraint (sparsify (c. f , x; kwargs... ), c. rhs, c. dim, c. flags)
175
178
end ) : NonconvexCore. VectorOfFunctions (NonconvexCore. EqConstraint[])
176
179
else
177
180
eq = model. eq_constraints
178
181
end
179
182
if sd_constraints
180
183
sd = length (model. sd_constraints. fs) != 0 ? NonconvexCore. VectorOfFunctions (map (model. sd_constraints. fs) do c
181
- return NonconvexCore. SDConstraint (sparsify (c, x; kwargs... ), c. dim)
184
+ return NonconvexCore. SDConstraint (sparsify (c. f , x; flatteny = false , kwargs... ), c. dim)
182
185
end ) : NonconvexCore. VectorOfFunctions (NonconvexCore. SDConstraint[])
183
186
else
184
187
sd = model. sd_constraints
0 commit comments