Skip to content

Commit c5974ee

Browse files
Kenojrevels
authored andcommitted
Remove V <: Real type restriction (#369)
Instead, use an extensible function that the constructor uses to check whether a type is valid to be used as `Dual`'s scalar type. Fixes #216
1 parent e1a129b commit c5974ee

File tree

1 file changed

+53
-30
lines changed

1 file changed

+53
-30
lines changed

src/dual.jl

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,22 @@
22
# Dual #
33
########
44

5-
struct Dual{T,V<:Real,N} <: Real
5+
"""
6+
ForwardDiff.can_dual(V::Type)
7+
8+
Determines whether the type V is allowed as the scalar type in a
9+
Dual. By default, only `<:Real` types are allowed.
10+
"""
11+
can_dual(::Type{<:Real}) = true
12+
can_dual(::Type) = false
13+
14+
struct Dual{T,V,N} <: Real
615
value::V
716
partials::Partials{N,V}
17+
function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N}
18+
can_dual(V) || throw_cannot_dual(V)
19+
new{T, V, N}(value, partials)
20+
end
821
end
922

1023
##############
@@ -19,6 +32,11 @@ end
1932
Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} =
2033
print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)")
2134

35+
@noinline function throw_cannot_dual(V::Type)
36+
throw(ArgumentError("Cannot create a dual over scalar type $V." *
37+
" If the type behaves as a scalar, define FowardDiff.can_dual."))
38+
end
39+
2240
"""
2341
ForwardDiff.≺(a, b)::Bool
2442
@@ -41,38 +59,42 @@ tag can be extracted, so it should be used in the _innermost_ function.
4159
return Dual{T}(convert(C, value), convert(Partials{N,C}, partials))
4260
end
4361

44-
@inline Dual{T}(value::Real, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
45-
@inline Dual{T}(value::Real, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
46-
@inline Dual{T}(value::Real, partials::Real...) where {T} = Dual{T}(value, partials)
47-
@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V<:Real,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
62+
@inline Dual{T}(value, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
63+
@inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
64+
@inline Dual{T}(value) where {T} = Dual{T}(value, ())
65+
@inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ())
66+
@inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...))
67+
@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
4868
@inline Dual(args...) = Dual{Nothing}(args...)
4969

5070
# we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
51-
@inline Dual{T,V,N}(x::Real) where {T,V,N} = convert(Dual{T,V,N}, x)
52-
@inline Dual{T,V}(x::Real) where {T,V} = convert(Dual{T,V}, x)
71+
@inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x
72+
@inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x)
73+
@inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x)
74+
@inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x)
5375

5476
##############################
5577
# Utility/Accessor Functions #
5678
##############################
5779

58-
@inline value(x::Real) = x
80+
@inline value(x) = x
5981
@inline value(d::Dual) = d.value
6082

61-
@inline value(::Type{T}, x::Real) where T = x
83+
@inline value(::Type{T}, x) where T = x
6284
@inline value(::Type{T}, d::Dual{T}) where T = value(d)
6385
function value(::Type{T}, d::Dual{S}) where {T,S}
6486
# TODO: in the case of nested Duals, it may be possible to "transpose" the Dual objects
6587
throw(DualMismatchError(T,S))
6688
end
6789

68-
@inline partials(x::Real) = Partials{0,typeof(x)}(tuple())
90+
@inline partials(x) = Partials{0,typeof(x)}(tuple())
6991
@inline partials(d::Dual) = d.partials
70-
@inline partials(x::Real, i...) = zero(x)
92+
@inline partials(x, i...) = zero(x)
7193
@inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i]
7294
@inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j]
7395
@inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...)
7496

75-
@inline Base.@propagate_inbounds partials(::Type{T}, x::Real, i...) where T = partials(x, i...)
97+
@inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...)
7698
@inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...)
7799
partials(::Type{T}, d::Dual{S}, i...) where {T,S} = throw(DualMismatchError(T,S))
78100

@@ -289,7 +311,7 @@ end
289311
########################
290312

291313
Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
292-
::Type{Dual{T2,V2,N2}}) where {T1,V1<:Real,N1,T2,V2<:Real,N2}
314+
::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2}
293315
# V1 and V2 might themselves be Dual types
294316
if T2 T1
295317
Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1}
@@ -299,26 +321,27 @@ Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
299321
end
300322

301323
function Base.promote_rule(::Type{Dual{T,A,N}},
302-
::Type{Dual{T,B,N}}) where {T,A<:Real,B<:Real,N}
324+
::Type{Dual{T,B,N}}) where {T,A,B,N}
303325
return Dual{T,promote_type(A, B),N}
304326
end
305327

306328
for R in (Irrational, Real, BigFloat, Bool)
307329
if isconcretetype(R) # issue #322
308330
@eval begin
309-
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V<:Real,N} = Dual{T,promote_type($R, V),N}
310-
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V<:Real,N} = Dual{T,promote_type(V, $R),N}
331+
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N}
332+
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N}
311333
end
312334
else
313335
@eval begin
314-
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V<:Real,N} = Dual{T,promote_type(R, V),N}
315-
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V<:Real,N,R<:$R} = Dual{T,promote_type(V, R),N}
336+
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N}
337+
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N}
316338
end
317339
end
318340
end
319341

320-
Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V<:Real,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d)))
321-
Base.convert(::Type{Dual{T,V,N}}, x::Real) where {T,V<:Real,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
342+
Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d)))
343+
Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
344+
Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
322345
Base.convert(::Type{D}, d::D) where {D<:Dual} = d
323346

324347
Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
@@ -468,9 +491,9 @@ end
468491
# fma #
469492
#-----#
470493

471-
@generated function calc_fma_xyz(x::Dual{T,<:Real,N},
472-
y::Dual{T,<:Real,N},
473-
z::Dual{T,<:Real,N}) where {T,N}
494+
@generated function calc_fma_xyz(x::Dual{T,<:Any,N},
495+
y::Dual{T,<:Any,N},
496+
z::Dual{T,<:Any,N}) where {T,N}
474497
ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
475498
return quote
476499
$(Expr(:meta, :inline))
@@ -485,9 +508,9 @@ end
485508
return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
486509
end
487510

488-
@generated function calc_fma_xz(x::Dual{T,<:Real,N},
511+
@generated function calc_fma_xz(x::Dual{T,<:Any,N},
489512
y::Real,
490-
z::Dual{T,<:Real,N}) where {T,N}
513+
z::Dual{T,<:Any,N}) where {T,N}
491514
ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
492515
return quote
493516
$(Expr(:meta, :inline))
@@ -510,9 +533,9 @@ end
510533
# muladd #
511534
#--------#
512535

513-
@generated function calc_muladd_xyz(x::Dual{T,<:Real,N},
514-
y::Dual{T,<:Real,N},
515-
z::Dual{T,<:Real,N}) where {T,N}
536+
@generated function calc_muladd_xyz(x::Dual{T,<:Any,N},
537+
y::Dual{T,<:Any,N},
538+
z::Dual{T,<:Any,N}) where {T,N}
516539
ex = Expr(:tuple, [:(muladd(value(x), partials(y)[$i], muladd(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
517540
return quote
518541
$(Expr(:meta, :inline))
@@ -527,9 +550,9 @@ end
527550
return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
528551
end
529552

530-
@generated function calc_muladd_xz(x::Dual{T,<:Real,N},
553+
@generated function calc_muladd_xz(x::Dual{T,<:Any,N},
531554
y::Real,
532-
z::Dual{T,<:Real,N}) where {T,N}
555+
z::Dual{T,<:Any,N}) where {T,N}
533556
ex = Expr(:tuple, [:(muladd(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
534557
return quote
535558
$(Expr(:meta, :inline))

0 commit comments

Comments
 (0)