Skip to content

Commit d70ba91

Browse files
dkarraschdevmotion
andauthored
Handle matrix times matrix = vector case (#227)
* Handle matrix times matrix = vector case * add test * actually cover it * use LinearAlgebra * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * avoid ambiguity, add tests * Update test/api/GradientTests.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fix --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent f1f3d1f commit d70ba91

File tree

4 files changed

+42
-5
lines changed

4 files changed

+42
-5
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.14.5"
3+
version = "1.14.6"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/derivatives/linalg/arithmetic.jl

+19-3
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,17 @@ end
270270
# a * b
271271

272272
function reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp)
273-
istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b))))
273+
if istracked(a)
274+
if a_tmp isa AbstractVector && b isa AbstractMatrix
275+
# this branch is required for scalar-valued functions that
276+
# involve outer-products of vectors, for such functions, the target
277+
# a_temp is a vector, but when b is a matrix, we cannot multiply into a vector,
278+
# so need to reshape memory to look like matrix (see PositiveFactorizations.jl)
279+
increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, transpose(value(b))))
280+
else
281+
increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b))))
282+
end
283+
end
274284
istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv))
275285
end
276286

@@ -279,8 +289,14 @@ for (f, F) in ((:transpose, :Transpose), (:adjoint, :Adjoint))
279289
# a * f(b)
280290
function reverse_mul!(output, output_deriv, a, b::$F, a_tmp, b_tmp)
281291
_b = ($f)(b)
282-
istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, mulargvalue(b)))
283-
istracked(_b) && increment_deriv!(_b, ($f)(mul!(b_tmp, ($f)(output_deriv), value(a))))
292+
if istracked(a)
293+
if a_tmp isa AbstractVector
294+
increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, mulargvalue(_b)))
295+
else
296+
increment_deriv!(a, mul!(a_tmp, output_deriv, mulargvalue(b)))
297+
end
298+
end
299+
istracked(_b) && increment_deriv!(_b, ($f)(mul!(($f)(b_tmp), ($f)(output_deriv), value(a))))
284300
end
285301
# f(a) * b
286302
function reverse_mul!(output, output_deriv, a::$F, b, a_tmp, b_tmp)

test/api/GradientTests.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module GradientTests
22

3-
using DiffTests, ForwardDiff, ReverseDiff, Test
3+
using DiffTests, ForwardDiff, ReverseDiff, Test, LinearAlgebra
44

55
include(joinpath(dirname(@__FILE__), "../utils.jl"))
66

@@ -187,6 +187,20 @@ for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
187187
test_unary_gradient(f, rand(5))
188188
end
189189

190+
# PR #227
191+
norm_hermitian1(v) = (A = I - 2 * v * v'; norm(A' * A))
192+
norm_hermitian2(v) = (A = I - 2 * v * transpose(v); norm(transpose(A) * A))
193+
norm_hermitian3(v) = (A = I - 2 * v * collect(v'); norm(collect(A') * A))
194+
norm_hermitian4(v) = (A = I - 2 * v * v'; norm(transpose(A) * A))
195+
norm_hermitian5(v) = (A = I - 2 * v * transpose(v); norm(A' * A))
196+
norm_hermitian6(v) = (A = (v'v)*I - 2 * v * v'; norm(A' * A))
197+
198+
for f in (norm_hermitian1, norm_hermitian2, norm_hermitian3,
199+
norm_hermitian4, norm_hermitian5, norm_hermitian6)
200+
test_println("VECTOR_TO_NUMBER_FUNCS", f)
201+
test_unary_gradient(f, rand(5))
202+
end
203+
190204
for f in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS
191205
test_println("TERNARY_MATRIX_TO_NUMBER_FUNCS", f)
192206
test_ternary_gradient(f, rand(5, 5), rand(5, 5), rand(5, 5))

test/derivatives/LinAlgTests.jl

+7
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,15 @@ for f in (
223223
test_arr2num(f, x, tp)
224224
end
225225

226+
# PR #227
227+
function norm_hermitian(v)
228+
A = I - 2 * v * v'
229+
return norm(A' * A)
230+
end
231+
226232
for f in (
227233
y -> vec(y)' * Matrix{Float64}(I, length(y), length(y)) * vec(y),
234+
norm_hermitian,
228235
)
229236
test_println("Array -> Number functions", f)
230237
test_arr2num(f, x, tp, ignore_tape_length=true)

0 commit comments

Comments
 (0)