Skip to content

Commit 6650adc

Browse files
Simone Carlo Suracesimsurace
Simone Carlo Surace
authored andcommitted
Enable all tests
1 parent 72060d0 commit 6650adc

File tree

2 files changed

+166
-166
lines changed

2 files changed

+166
-166
lines changed

test/rulesets/LinearAlgebra/dense.jl

+148-148
Original file line numberDiff line numberDiff line change
@@ -1,164 +1,164 @@
11
@testset "dense LinearAlgebra" begin
2-
# @testset "dot" begin
3-
# @testset "Vector{$T}" for T in (Float64, ComplexF64)
4-
# @gpu test_frule(dot, randn(T, 3), randn(T, 3))
5-
# @gpu test_rrule(dot, randn(T, 3), randn(T, 3))
6-
# end
7-
# @testset "Array{$T, 3}" for T in (Float64, ComplexF64)
8-
# test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
9-
# test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
10-
# end
11-
# @testset "mismatched shapes" begin
12-
# # forward
13-
# @gpu test_frule(dot, randn(3, 5), randn(5, 3))
14-
# @gpu test_frule(dot, randn(15), randn(5, 3))
15-
# # reverse
16-
# @gpu test_rrule(dot, randn(3, 5), randn(5, 3))
17-
# @gpu test_rrule(dot, randn(15), randn(5, 3))
18-
# end
19-
# @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
20-
# @gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
21-
# @gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
22-
# end
23-
# permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
24-
# @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
25-
# A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3))
26-
# test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
27-
# test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
28-
# end
29-
# @testset "different types" begin
30-
# test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2))
31-
# test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2))
2+
@testset "dot" begin
3+
@testset "Vector{$T}" for T in (Float64, ComplexF64)
4+
@gpu test_frule(dot, randn(T, 3), randn(T, 3))
5+
@gpu test_rrule(dot, randn(T, 3), randn(T, 3))
6+
end
7+
@testset "Array{$T, 3}" for T in (Float64, ComplexF64)
8+
test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
9+
test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
10+
end
11+
@testset "mismatched shapes" begin
12+
# forward
13+
@gpu test_frule(dot, randn(3, 5), randn(5, 3))
14+
@gpu test_frule(dot, randn(15), randn(5, 3))
15+
# reverse
16+
@gpu test_rrule(dot, randn(3, 5), randn(5, 3))
17+
@gpu test_rrule(dot, randn(15), randn(5, 3))
18+
end
19+
@testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
20+
@gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
21+
@gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
22+
end
23+
permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
24+
@testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
25+
A = F(rand(T, 4, 3)) F(rand(T, 4, 3))
26+
test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
27+
test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
28+
end
29+
@testset "different types" begin
30+
test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2))
31+
test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2))
3232

33-
# # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407
34-
# test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false)
35-
# end
36-
# end
33+
# Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407
34+
test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false)
35+
end
36+
end
3737

38-
# @testset "mul!" begin
39-
# test_frule(mul!, rand(4), rand(4, 5), rand(5))
40-
# test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3))
41-
# test_frule(mul!, rand(3, 3), rand(), rand(3, 3))
38+
@testset "mul!" begin
39+
test_frule(mul!, rand(4), rand(4, 5), rand(5))
40+
test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3))
41+
test_frule(mul!, rand(3, 3), rand(), rand(3, 3))
4242

43-
# # Rule with α,β::Bool is only visually more complicated:
44-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true)
45-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true)
46-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false)
47-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false)
43+
# Rule with α,β::Bool is only visually more complicated:
44+
test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true)
45+
test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true)
46+
test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false)
47+
test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false)
4848

49-
# # Rule with nontrivial α, β allocates A*B:
50-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn())
51-
# test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn())
52-
# end
49+
# Rule with nontrivial α, β allocates A*B:
50+
test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn())
51+
test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn())
52+
end
5353

54-
# @testset "cross" begin
55-
# test_frule(cross, randn(3), randn(3))
56-
# test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))
57-
# test_rrule(cross, randn(3), randn(3))
58-
# # No complex support for rrule(cross,...
54+
@testset "cross" begin
55+
test_frule(cross, randn(3), randn(3))
56+
test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))
57+
test_rrule(cross, randn(3), randn(3))
58+
# No complex support for rrule(cross,...
5959

60-
# # mix types
61-
# test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7)
62-
# end
63-
# @testset "pinv" begin
64-
# @testset "$T" for T in (Float64, ComplexF64)
65-
# test_scalar(pinv, randn(T))
66-
# @test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] ≈ zero(T)
67-
# @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T)
68-
# end
69-
# @testset "Vector{$T}" for T in (Float64, ComplexF64)
70-
# test_frule(pinv, randn(T, 3), 0.0)
71-
# test_frule(pinv, randn(T, 3), 0.0)
60+
# mix types
61+
test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7)
62+
end
63+
@testset "pinv" begin
64+
@testset "$T" for T in (Float64, ComplexF64)
65+
test_scalar(pinv, randn(T))
66+
@test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] zero(T)
67+
@test rrule(pinv, zero(T))[2](randn(T))[2] zero(T)
68+
end
69+
@testset "Vector{$T}" for T in (Float64, ComplexF64)
70+
test_frule(pinv, randn(T, 3), 0.0)
71+
test_frule(pinv, randn(T, 3), 0.0)
7272

73-
# # Checking types. TODO do we still need this?
74-
# x = randn(T, 3)
75-
# ẋ = randn(T, 3)
76-
# Δy = copyto!(similar(pinv(x)), randn(T, 3))
77-
# @test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x))
78-
# @test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
79-
# end
73+
# Checking types. TODO do we still need this?
74+
x = randn(T, 3)
75+
= randn(T, 3)
76+
Δy = copyto!(similar(pinv(x)), randn(T, 3))
77+
@test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x))
78+
@test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
79+
end
8080

81-
# @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
82-
# test_frule(pinv, F(randn(T, 3)))
83-
# test_rrule(pinv, F(randn(T, 3)))
81+
@testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
82+
test_frule(pinv, F(randn(T, 3)))
83+
test_rrule(pinv, F(randn(T, 3)))
8484

85-
# # Check types.
86-
# # TODO: Do we need this still?
87-
# x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3))
88-
# y = pinv(x)
89-
# Δy = copyto!(similar(y), randn(T, 3))
85+
# Check types.
86+
# TODO: Do we need this still?
87+
x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3))
88+
y = pinv(x)
89+
Δy = copyto!(similar(y), randn(T, 3))
9090

91-
# y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x)
92-
# @test y_fwd isa typeof(y)
93-
# @test ∂y_fwd isa typeof(y)
91+
y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x)
92+
@test y_fwd isa typeof(y)
93+
@test ∂y_fwd isa typeof(y)
9494

95-
# y_rev, back = rrule(pinv, x)
96-
# @test y_rev isa typeof(y)
97-
# @test back(Δy)[2] isa typeof(x)
98-
# end
99-
# @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64),
100-
# m in 1:3,
101-
# n in 1:3
95+
y_rev, back = rrule(pinv, x)
96+
@test y_rev isa typeof(y)
97+
@test back(Δy)[2] isa typeof(x)
98+
end
99+
@testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64),
100+
m in 1:3,
101+
n in 1:3
102102

103-
# test_frule(pinv, randn(T, m, n))
104-
# test_rrule(pinv, randn(T, m, n))
105-
# end
106-
# end
107-
# @testset "$f" for f in (det, logdet)
108-
# @testset "$f(::$T)" for T in (Float64, ComplexF64)
109-
# b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T)
110-
# test_scalar(f, b)
111-
# end
112-
# @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64)
113-
# B = generate_well_conditioned_matrix(T, 4)
114-
# if f === logdet && float(T) <: Float32
115-
# test_frule(f, B; atol=1e-5, rtol=1e-5)
116-
# test_rrule(f, B; atol=1e-5, rtol=1e-5)
117-
# else
118-
# test_frule(f, B)
119-
# test_rrule(f, B)
120-
# end
121-
# end
122-
# @testset "$f(complex determinant)" begin
123-
# B = randn(ComplexF64, 4, 4)
124-
# U = exp(B - B')
125-
# test_frule(f, U)
126-
# test_rrule(f, U)
127-
# end
128-
# @testset "gpu" begin
129-
# @gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi)
130-
# end
131-
# end
132-
# @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
133-
# B = randn(T, 4, 4)
134-
# test_frule(logabsdet, B)
135-
# test_rrule(logabsdet, B)
136-
# # test for opposite sign of determinant
137-
# test_frule(logabsdet, -B)
138-
# test_rrule(logabsdet, -B)
139-
# end
140-
# @testset "tr" begin
141-
# @gpu test_frule(tr, randn(4, 4))
142-
# @gpu test_rrule(tr, randn(4, 4))
143-
# end
144-
# @testset "sylvester" begin
145-
# @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3)
146-
# A = randn(T, m, m)
147-
# B = randn(T, n, n)
148-
# C = randn(T, m, n)
149-
# test_frule(sylvester, A, B, C)
150-
# test_rrule(sylvester, A, B, C)
151-
# end
152-
# end
153-
# @testset "lyap" begin
154-
# n = 3
155-
# @testset "Float64" for T in (Float64, ComplexF64)
156-
# A = randn(T, n, n)
157-
# C = randn(T, n, n)
158-
# test_frule(lyap, A, C)
159-
# test_rrule(lyap, A, C)
160-
# end
161-
# end
103+
test_frule(pinv, randn(T, m, n))
104+
test_rrule(pinv, randn(T, m, n))
105+
end
106+
end
107+
@testset "$f" for f in (det, logdet)
108+
@testset "$f(::$T)" for T in (Float64, ComplexF64)
109+
b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T)
110+
test_scalar(f, b)
111+
end
112+
@testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64)
113+
B = generate_well_conditioned_matrix(T, 4)
114+
if f === logdet && float(T) <: Float32
115+
test_frule(f, B; atol=1e-5, rtol=1e-5)
116+
test_rrule(f, B; atol=1e-5, rtol=1e-5)
117+
else
118+
test_frule(f, B)
119+
test_rrule(f, B)
120+
end
121+
end
122+
@testset "$f(complex determinant)" begin
123+
B = randn(ComplexF64, 4, 4)
124+
U = exp(B - B')
125+
test_frule(f, U)
126+
test_rrule(f, U)
127+
end
128+
@testset "gpu" begin
129+
@gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi)
130+
end
131+
end
132+
@testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
133+
B = randn(T, 4, 4)
134+
test_frule(logabsdet, B)
135+
test_rrule(logabsdet, B)
136+
# test for opposite sign of determinant
137+
test_frule(logabsdet, -B)
138+
test_rrule(logabsdet, -B)
139+
end
140+
@testset "tr" begin
141+
@gpu test_frule(tr, randn(4, 4))
142+
@gpu test_rrule(tr, randn(4, 4))
143+
end
144+
@testset "sylvester" begin
145+
@testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3)
146+
A = randn(T, m, m)
147+
B = randn(T, n, n)
148+
C = randn(T, m, n)
149+
test_frule(sylvester, A, B, C)
150+
test_rrule(sylvester, A, B, C)
151+
end
152+
end
153+
@testset "lyap" begin
154+
n = 3
155+
@testset "Float64" for T in (Float64, ComplexF64)
156+
A = randn(T, n, n)
157+
C = randn(T, n, n)
158+
test_frule(lyap, A, C)
159+
test_rrule(lyap, A, C)
160+
end
161+
end
162162
VERSION v"1.9.0" && @testset "kron" begin
163163
@testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64)
164164
@testset "frule" begin

test/runtests.jl

+18-18
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ end
5050
include("test_helpers.jl") # This can't be skipped
5151
println()
5252

53-
# test_method_tables() # Check the global method tables are consistent
53+
test_method_tables() # Check the global method tables are consistent
5454

5555
# Each file puts all tests inside one or more @testset blocks
5656
include_test("rulesets/Base/CoreLogging.jl")
@@ -64,30 +64,30 @@ end
6464
include_test("rulesets/Base/sort.jl")
6565
include_test("rulesets/Base/broadcast.jl")
6666

67-
# include_test("unzipped.jl") # used primarily for broadcast
67+
include_test("unzipped.jl") # used primarily for broadcast
6868

69-
# println()
69+
println()
7070

71-
# include_test("rulesets/Statistics/statistics.jl")
71+
include_test("rulesets/Statistics/statistics.jl")
7272

73-
# println()
73+
println()
7474

7575
include_test("rulesets/LinearAlgebra/dense.jl")
76-
# include_test("rulesets/LinearAlgebra/norm.jl")
77-
# include_test("rulesets/LinearAlgebra/matfun.jl")
78-
# include_test("rulesets/LinearAlgebra/structured.jl")
79-
# include_test("rulesets/LinearAlgebra/symmetric.jl")
80-
# include_test("rulesets/LinearAlgebra/factorization.jl")
81-
# include_test("rulesets/LinearAlgebra/blas.jl")
82-
# include_test("rulesets/LinearAlgebra/lapack.jl")
83-
# include_test("rulesets/LinearAlgebra/uniformscaling.jl")
76+
include_test("rulesets/LinearAlgebra/norm.jl")
77+
include_test("rulesets/LinearAlgebra/matfun.jl")
78+
include_test("rulesets/LinearAlgebra/structured.jl")
79+
include_test("rulesets/LinearAlgebra/symmetric.jl")
80+
include_test("rulesets/LinearAlgebra/factorization.jl")
81+
include_test("rulesets/LinearAlgebra/blas.jl")
82+
include_test("rulesets/LinearAlgebra/lapack.jl")
83+
include_test("rulesets/LinearAlgebra/uniformscaling.jl")
8484

85-
# println()
85+
println()
8686

87-
# include_test("rulesets/SparseArrays/sparsematrix.jl")
87+
include_test("rulesets/SparseArrays/sparsematrix.jl")
8888

89-
# println()
89+
println()
9090

91-
# include_test("rulesets/Random/random.jl")
92-
# println()
91+
include_test("rulesets/Random/random.jl")
92+
println()
9393
end

0 commit comments

Comments
 (0)