Skip to content

Commit 58758f6

Browse files
authored
fix FillArrays 0.10 compat (#157)
* fix FillArrays compat * add tests * fix FillArrays in toml * define a module
1 parent 71c5ac0 commit 58758f6

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ julia = "1"
2828

2929
[extras]
3030
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
31+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233

3334
[targets]
34-
test = ["DiffTests", "Test"]
35+
test = ["DiffTests", "FillArrays", "Test"]

src/derivatives/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,17 @@ function get_implementation(bc, f, T, args)
8585
end
8686
function Base.copy(_bc::Broadcasted{TrackedStyle})
8787
bc = remove_not_tracked(_bc)
88-
flattened_bc = Broadcast.flatten(bc)
88+
flattened_bc = Base.Broadcast.flatten(bc)
8989
untracked_bc = broadcast_rebuild(bc)
90-
flattened_untracked_bc = Broadcast.flatten(untracked_bc)
9190
T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
92-
f, args = flattened_untracked_bc.f, flattened_bc.args
91+
f, args = flattened_bc.f, flattened_bc.args
9392
implementation = get_implementation(_bc, f, T, args)
9493
if implementation isa Val{:reversediff}
9594
return ∇broadcast(f, args...)
9695
elseif implementation isa Val{:tracker}
9796
return tracker_∇broadcast(f, args...)
9897
else
98+
flattened_untracked_bc = Base.Broadcast.flatten(untracked_bc)
9999
style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes
100100
return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
101101
end

test/compat/CompatTests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module CompatTests
2+
3+
using FillArrays, ReverseDiff, Test
4+
5+
@test ReverseDiff.gradient(fill(2.0, 3)) do x
6+
sum(abs2.(x .- Zeros(3)))
7+
end == fill(4.0, 3)
8+
9+
@test ReverseDiff.gradient(fill(2.0, 3)) do x
10+
sum(abs2.(x .- (1:3)))
11+
end == [2, 0, -2]
12+
13+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ println("done (took $t seconds).")
4545
println("running ConfigTests...")
4646
t = @elapsed include(joinpath(TESTDIR, "api/ConfigTests.jl"))
4747
println("done (took $t seconds).")
48+
49+
println("running CompatTests...")
50+
t = @elapsed include(joinpath(TESTDIR, "compat/CompatTests.jl"))
51+
println("done (took $t seconds).")

0 commit comments

Comments
 (0)