Skip to content

Commit 9d4352e

Browse files
committed
update tests, CR version
1 parent 805dcf9 commit 9d4352e

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1313
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1414

1515
[compat]
16-
ChainRules = "1.5"
17-
ChainRulesCore = "1.2"
16+
ChainRules = "1.44.6"
17+
ChainRulesCore = "1.15.3"
1818
Combinatorics = "1"
1919
StaticArrays = "1"
2020
StatsBase = "0.33"

src/extra_rules.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,5 +266,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
266266
val, Δ->(NoTangent(), NoTangent(), Δ)
267267
end
268268

269-
Base.real(z::ZeroTangent) = z # TODO should be in CRC
270-
Base.real(z::NoTangent) = z
269+
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581

test/runtests.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,17 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
265265
end
266266

267267
@testset "broadcast, 2nd order" begin
268-
@test gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6]
269-
@test gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
268+
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
269+
@test gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] broken=VERSION<v"1.8" # On 1.7: Control flow support not fully implemented yet for higher-order
270+
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
271+
272+
@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
273+
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
270274
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]
271275

272276
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
273277
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
274-
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline))))
278+
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
275279
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]
276280

277281
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)

0 commit comments

Comments
 (0)