Skip to content

Commit 805dcf9

Browse files
committed
tests
1 parent f87143b commit 805dcf9

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/runtime.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing
2727

2828
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
2929
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
30+
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()

test/runtests.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
219219
@test z45 2.0
220220
@test delta45 1.0
221221

222+
# PR #82 - getindex on non-numeric arrays
223+
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}
224+
222225
@testset "broadcast" begin
223226
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
224227
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
@@ -257,13 +260,21 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
257260
@test tup_adj[2] [0.6666666666666666 0.5 0.4]
258261
@test tup_adj[2] isa Transpose
259262
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
263+
264+
@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
260265
end
261266

262267
@testset "broadcast, 2nd order" begin
268+
@test gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6]
269+
@test gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
270+
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]
271+
263272
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
264-
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3.0])[1] exp.(1:3)
273+
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
265274
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline))))
266275
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]
276+
277+
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
267278
end
268279

269280
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)

0 commit comments

Comments
 (0)