@@ -219,6 +219,9 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
219
219
@test z45 ≈ 2.0
220
220
@test delta45 ≈ 1.0
221
221
222
+ # PR #82 - getindex on non-numeric arrays
223
+ @test gradient (ls -> ls[1 ](1. ), [Base. Fix1 (* , 1. )])[1 ][1 ] isa Tangent{<: Base.Fix1 }
224
+
222
225
@testset " broadcast" begin
223
226
@test gradient (x -> sum (x ./ x), [1 ,2 ,3 ]) == ([0 ,0 ,0 ],) # derivatives_given_output
224
227
@test gradient (x -> sum (sqrt .(atan .(x, transpose (x)))), [1 ,2 ,3 ])[1 ] ≈ [0.2338 , - 0.0177 , - 0.0661 ] atol= 1e-3
@@ -257,13 +260,21 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
257
260
@test tup_adj[2 ] ≈ [0.6666666666666666 0.5 0.4 ]
258
261
@test tup_adj[2 ] isa Transpose
259
262
@test gradient (x -> sum (atan .(x, (1 ,2 ,3 ))), Diagonal ([4 ,5 ,6 ]))[1 ] isa Diagonal
263
+
264
+ @test gradient (x -> sum ((y -> (x* y)). ([1 ,2 ,3 ])), 4.0 ) == (6.0 ,) # closure
260
265
end
261
266
262
267
@testset " broadcast, 2nd order" begin
268
+ @test gradient (x -> sum (gradient (x -> sum (x .^ 2 .+ x' ), x)[1 ]), [1 ,2 ,3.0 ])[1 ] == [6 ,6 ,6 ]
269
+ @test gradient (x -> sum (gradient (x -> sum ((x .+ 1 ) .* x .- x), x)[1 ]), [1 ,2 ,3.0 ])[1 ] == [2 ,2 ,2 ]
270
+ @test_broken gradient (x -> sum (gradient (x -> sum (x .* x ./ 2 ), x)[1 ]), [1 ,2 ,3.0 ])[1 ] == [1 ,1 ,1 ]
271
+
263
272
@test_broken gradient (x -> sum (gradient (x -> sum (exp .(x)), x)[1 ]), [1 ,2 ,3 ])[1 ] ≈ exp .(1 : 3 ) # MethodError: no method matching copy(::Nothing)
264
- @test_broken gradient (x -> sum (gradient (x -> sum (exp .(x)), x)[1 ]), [1 ,2 ,3.0 ])[1 ] ≈ exp .( 1 : 3 )
273
+ @test_broken gradient (x -> sum (gradient (x -> sum (atan .(x, x ' )), x)[1 ]), [1 ,2 ,3.0 ])[1 ] ≈ [ 0 , 0 , 0 ]
265
274
@test_broken gradient (x -> sum (gradient (x -> sum (transpose (x) .* x), x)[1 ]), [1 ,2 ,3 ]) == ([6 ,6 ,6 ],) # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline))))
266
275
@test_broken gradient (x -> sum (gradient (x -> sum (transpose (x) ./ x.^ 2 ), x)[1 ]), [1 ,2 ,3 ])[1 ] ≈ [27.675925925925927 , - 0.824074074074074 , - 2.1018518518518516 ]
276
+
277
+ @test_broken gradient (z -> gradient (x -> sum ((y -> (x^ 2 * y)). ([1 ,2 ,3 ])), z)[1 ], 5.0 ) == (12.0 ,)
267
278
end
268
279
269
280
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
0 commit comments