Skip to content

Commit d7ebc3d

Browse files
tkfjrevels
authored andcommitted
Enable SIMD with Base.literal_pow (#332)
1 parent 9fe228c commit d7ebc3d

File tree

4 files changed

+48
-9
lines changed

4 files changed

+48
-9
lines changed

.travis.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ notifications:
1010
sudo: false
1111
script:
1212
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
13-
- julia -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.test("ForwardDiff"; coverage=true)';
14-
- julia -O3 -e 'include("test/SIMDTest.jl")';
13+
- julia --color=yes -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.test("ForwardDiff"; coverage=true)';
14+
- julia --color=yes -O3 -e 'using Pkg; Pkg.add("StaticArrays"); include("test/SIMDTest.jl")';
1515
after_success:
16-
- julia -e 'Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'
17-
- julia -e 'Pkg.add("Documenter")'
18-
- julia -e 'include("docs/make.jl")'
16+
- julia --color=yes -e 'Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'
17+
- julia --color=yes -e 'Pkg.add("Documenter")'
18+
- julia --color=yes -e 'include("docs/make.jl")'

src/dual.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,18 @@ for f in (:(Base.:^), :(NaNMath.pow))
426426
end
427427
end
428428

429+
@inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
430+
Dual{T}(one(value(x)), zero(partials(x)))
431+
432+
for y in 1:3
433+
@eval @inline function Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{$y}) where {T}
434+
v = value(x)
435+
expv = v^$y
436+
deriv = $y * v^$(y - 1)
437+
return Dual{T}(expv, deriv * partials(x))
438+
end
439+
end
440+
429441
# hypot #
430442
#-------#
431443

test/DualTest.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,10 @@ end
460460
x1 = Dual{:t1}(x0, 1.0)
461461
x2 = Dual{:t2}(x1, 1.0)
462462
x3 = Dual{:t3}(x2, 1.0)
463-
@test x3^2 === x3 * x3
464-
@test x2^1 === x2
465-
@test x1^0 === Dual{:t1}(1.0, 0.0)
463+
pow = ^ # to call non-literal power
464+
@test pow(x3, 2) === x3^2 === x3 * x3
465+
@test pow(x2, 1) === x2^1 === x2
466+
@test pow(x1, 0) === x1^0 === Dual{:t1}(1.0, 0.0)
466467
end
467468

468469
end # module

test/SIMDTest.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SIMDTest
33
using Test
44
using ForwardDiff: Dual, valtype
55
using InteractiveUtils: code_llvm
6+
using StaticArrays: SVector
67

78
const DUALS = (Dual(1., 2., 3., 4.),
89
Dual(1., 2., 3., 4., 5.),
@@ -17,7 +18,7 @@ function simd_sum(x::Vector{T}) where T
1718
return s
1819
end
1920

20-
for D in map(typeof, DUALS)
21+
@testset "SIMD $D" for D in map(typeof, DUALS)
2122
plus_bitcode = sprint(io -> code_llvm(io, +, (D, D)))
2223
@test occursin("fadd <4 x double>", plus_bitcode)
2324

@@ -32,6 +33,9 @@ for D in map(typeof, DUALS)
3233
@test occursin(r"fadd \<.*?x double\>", div_bitcode)
3334
@test occursin(r"fmul \<.*?x double\>", div_bitcode)
3435

36+
pow_bitcode = sprint(io -> code_llvm(io, ^, (D, Int)))
37+
@test occursin(r"fmul \<.*?x double\>", pow_bitcode)
38+
3539
exp_bitcode = sprint(io -> code_llvm(io, ^, (D, D)))
3640
@test occursin(r"fadd \<.*?x double\>", exp_bitcode)
3741
if !(valtype(D) <: Dual)
@@ -44,4 +48,26 @@ for D in map(typeof, DUALS)
4448
end
4549
end
4650

51+
# `pow2dot` is chosen so that `@code_llvm pow2dot(SVector(1:1.0:4...))`
52+
# generates code with SIMD instructions.
53+
# See:
54+
# https://github.com/JuliaDiff/ForwardDiff.jl/pull/332
55+
# https://github.com/JuliaDiff/ForwardDiff.jl/pull/331#issuecomment-406107260
56+
@inline pow2(x) = x^2
57+
pow2dot(xs) = pow2.(xs)
58+
59+
# Nested dual such as `Dual(Dual(1., 2.), Dual(3., 4.))` only produces
60+
# "fmul <2 x double>" so it is excluded from the following test.
61+
const POW_DUALS = (Dual(1., 2.),
62+
Dual(1., 2., 3.),
63+
Dual(1., 2., 3., 4.),
64+
Dual(1., 2., 3., 4., 5.))
65+
66+
@testset "SIMD square of $D" for D in map(typeof, POW_DUALS)
67+
pow_bitcode = sprint(io -> code_llvm(io, pow2dot, (SVector{4, D},)))
68+
@test occursin(r"(.*fmul \<4 x double\>){2}"s, pow_bitcode)
69+
# "{2}" is for asserting that fmul has to appear at least twice:
70+
# once for `.value` and once for `.partials`.
71+
end
72+
4773
end # module

0 commit comments

Comments
 (0)