2
2
# Dual #
3
3
# #######
4
4
5
- struct Dual{T,V<: Real ,N} <: Real
5
+ """
6
+ ForwardDiff.can_dual(V::Type)
7
+
8
+ Determines whether the type V is allowed as the scalar type in a
9
+ Dual. By default, only `<:Real` types are allowed.
10
+ """
11
+ can_dual (:: Type{<:Real} ) = true
12
+ can_dual (:: Type ) = false
13
+
14
+ struct Dual{T,V,N} <: Real
6
15
value:: V
7
16
partials:: Partials{N,V}
17
+ function Dual {T, V, N} (value:: V , partials:: Partials{N, V} ) where {T, V, N}
18
+ can_dual (V) || throw_cannot_dual (V)
19
+ new {T, V, N} (value, partials)
20
+ end
8
21
end
9
22
10
23
# #############
19
32
Base. showerror (io:: IO , e:: DualMismatchError{A,B} ) where {A,B} =
20
33
print (io, " Cannot determine ordering of Dual tags $(e. a) and $(e. b) " )
21
34
35
+ @noinline function throw_cannot_dual (V:: Type )
36
+ throw (ArgumentError (" Cannot create a dual over scalar type $V ." *
37
+ " If the type behaves as a scalar, define FowardDiff.can_dual." ))
38
+ end
39
+
22
40
"""
23
41
ForwardDiff.≺(a, b)::Bool
24
42
@@ -41,38 +59,42 @@ tag can be extracted, so it should be used in the _innermost_ function.
41
59
return Dual {T} (convert (C, value), convert (Partials{N,C}, partials))
42
60
end
43
61
44
- @inline Dual {T} (value:: Real , partials:: Tuple ) where {T} = Dual {T} (value, Partials (partials))
45
- @inline Dual {T} (value:: Real , partials:: Tuple{} ) where {T} = Dual {T} (value, Partials {0,typeof(value)} (partials))
46
- @inline Dual {T} (value:: Real , partials:: Real... ) where {T} = Dual {T} (value, partials)
47
- @inline Dual {T} (value:: V , :: Chunk{N} , p:: Val{i} ) where {T,V<: Real ,N,i} = Dual {T} (value, single_seed (Partials{N,V}, p))
62
+ @inline Dual {T} (value, partials:: Tuple ) where {T} = Dual {T} (value, Partials (partials))
63
+ @inline Dual {T} (value, partials:: Tuple{} ) where {T} = Dual {T} (value, Partials {0,typeof(value)} (partials))
64
+ @inline Dual {T} (value) where {T} = Dual {T} (value, ())
65
+ @inline Dual {T} (x:: Dual{T} ) where {T} = Dual {T} (x, ())
66
+ @inline Dual {T} (value, partial1, partials... ) where {T} = Dual {T} (value, tuple (partial1, partials... ))
67
+ @inline Dual {T} (value:: V , :: Chunk{N} , p:: Val{i} ) where {T,V,N,i} = Dual {T} (value, single_seed (Partials{N,V}, p))
48
68
@inline Dual (args... ) = Dual {Nothing} (args... )
49
69
50
70
# we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
51
- @inline Dual {T,V,N} (x:: Real ) where {T,V,N} = convert (Dual{T,V,N}, x)
52
- @inline Dual {T,V} (x:: Real ) where {T,V} = convert (Dual{T,V}, x)
71
+ @inline Dual {T,V,N} (x:: Dual{T,V,N} ) where {T,V,N} = x
72
+ @inline Dual {T,V,N} (x) where {T,V,N} = convert (Dual{T,V,N}, x)
73
+ @inline Dual {T,V,N} (x:: Number ) where {T,V,N} = convert (Dual{T,V,N}, x)
74
+ @inline Dual {T,V} (x) where {T,V} = convert (Dual{T,V}, x)
53
75
54
76
# #############################
55
77
# Utility/Accessor Functions #
56
78
# #############################
57
79
58
- @inline value (x:: Real ) = x
80
+ @inline value (x) = x
59
81
@inline value (d:: Dual ) = d. value
60
82
61
- @inline value (:: Type{T} , x:: Real ) where T = x
83
+ @inline value (:: Type{T} , x) where T = x
62
84
@inline value (:: Type{T} , d:: Dual{T} ) where T = value (d)
63
85
function value (:: Type{T} , d:: Dual{S} ) where {T,S}
64
86
# TODO : in the case of nested Duals, it may be possible to "transpose" the Dual objects
65
87
throw (DualMismatchError (T,S))
66
88
end
67
89
68
- @inline partials (x:: Real ) = Partials {0,typeof(x)} (tuple ())
90
+ @inline partials (x) = Partials {0,typeof(x)} (tuple ())
69
91
@inline partials (d:: Dual ) = d. partials
70
- @inline partials (x:: Real , i... ) = zero (x)
92
+ @inline partials (x, i... ) = zero (x)
71
93
@inline Base. @propagate_inbounds partials (d:: Dual , i) = d. partials[i]
72
94
@inline Base. @propagate_inbounds partials (d:: Dual , i, j) = partials (d, i). partials[j]
73
95
@inline Base. @propagate_inbounds partials (d:: Dual , i, j, k... ) = partials (partials (d, i, j), k... )
74
96
75
- @inline Base. @propagate_inbounds partials (:: Type{T} , x:: Real , i... ) where T = partials (x, i... )
97
+ @inline Base. @propagate_inbounds partials (:: Type{T} , x, i... ) where T = partials (x, i... )
76
98
@inline Base. @propagate_inbounds partials (:: Type{T} , d:: Dual{T} , i... ) where T = partials (d, i... )
77
99
partials (:: Type{T} , d:: Dual{S} , i... ) where {T,S} = throw (DualMismatchError (T,S))
78
100
289
311
# #######################
290
312
291
313
Base. @pure function Base. promote_rule (:: Type{Dual{T1,V1,N1}} ,
292
- :: Type{Dual{T2,V2,N2}} ) where {T1,V1<: Real ,N1,T2,V2<: Real ,N2}
314
+ :: Type{Dual{T2,V2,N2}} ) where {T1,V1,N1,T2,V2,N2}
293
315
# V1 and V2 might themselves be Dual types
294
316
if T2 ≺ T1
295
317
Dual{T1,promote_type (V1,Dual{T2,V2,N2}),N1}
@@ -299,26 +321,27 @@ Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
299
321
end
300
322
301
323
function Base. promote_rule (:: Type{Dual{T,A,N}} ,
302
- :: Type{Dual{T,B,N}} ) where {T,A<: Real ,B <: Real ,N}
324
+ :: Type{Dual{T,B,N}} ) where {T,A,B ,N}
303
325
return Dual{T,promote_type (A, B),N}
304
326
end
305
327
306
328
for R in (Irrational, Real, BigFloat, Bool)
307
329
if isconcretetype (R) # issue #322
308
330
@eval begin
309
- Base. promote_rule (:: Type{$R} , :: Type{Dual{T,V,N}} ) where {T,V<: Real ,N} = Dual{T,promote_type ($ R, V),N}
310
- Base. promote_rule (:: Type{Dual{T,V,N}} , :: Type{$R} ) where {T,V<: Real ,N} = Dual{T,promote_type (V, $ R),N}
331
+ Base. promote_rule (:: Type{$R} , :: Type{Dual{T,V,N}} ) where {T,V,N} = Dual{T,promote_type ($ R, V),N}
332
+ Base. promote_rule (:: Type{Dual{T,V,N}} , :: Type{$R} ) where {T,V,N} = Dual{T,promote_type (V, $ R),N}
311
333
end
312
334
else
313
335
@eval begin
314
- Base. promote_rule (:: Type{R} , :: Type{Dual{T,V,N}} ) where {R<: $R ,T,V<: Real ,N} = Dual{T,promote_type (R, V),N}
315
- Base. promote_rule (:: Type{Dual{T,V,N}} , :: Type{R} ) where {T,V<: Real ,N,R<: $R } = Dual{T,promote_type (V, R),N}
336
+ Base. promote_rule (:: Type{R} , :: Type{Dual{T,V,N}} ) where {R<: $R ,T,V,N} = Dual{T,promote_type (R, V),N}
337
+ Base. promote_rule (:: Type{Dual{T,V,N}} , :: Type{R} ) where {T,V,N,R<: $R } = Dual{T,promote_type (V, R),N}
316
338
end
317
339
end
318
340
end
319
341
320
- Base. convert (:: Type{Dual{T,V,N}} , d:: Dual{T} ) where {T,V<: Real ,N} = Dual {T} (convert (V, value (d)), convert (Partials{N,V}, partials (d)))
321
- Base. convert (:: Type{Dual{T,V,N}} , x:: Real ) where {T,V<: Real ,N} = Dual {T} (convert (V, x), zero (Partials{N,V}))
342
+ Base. convert (:: Type{Dual{T,V,N}} , d:: Dual{T} ) where {T,V,N} = Dual {T} (convert (V, value (d)), convert (Partials{N,V}, partials (d)))
343
+ Base. convert (:: Type{Dual{T,V,N}} , x) where {T,V,N} = Dual {T} (convert (V, x), zero (Partials{N,V}))
344
+ Base. convert (:: Type{Dual{T,V,N}} , x:: Number ) where {T,V,N} = Dual {T} (convert (V, x), zero (Partials{N,V}))
322
345
Base. convert (:: Type{D} , d:: D ) where {D<: Dual } = d
323
346
324
347
Base. float (d:: Dual{T,V,N} ) where {T,V,N} = convert (Dual{T,promote_type (V, Float16),N}, d)
468
491
# fma #
469
492
# -----#
470
493
471
- @generated function calc_fma_xyz (x:: Dual{T,<:Real ,N} ,
472
- y:: Dual{T,<:Real ,N} ,
473
- z:: Dual{T,<:Real ,N} ) where {T,N}
494
+ @generated function calc_fma_xyz (x:: Dual{T,<:Any ,N} ,
495
+ y:: Dual{T,<:Any ,N} ,
496
+ z:: Dual{T,<:Any ,N} ) where {T,N}
474
497
ex = Expr (:tuple , [:(fma (value (x), partials (y)[$ i], fma (value (y), partials (x)[$ i], partials (z)[$ i]))) for i in 1 : N]. .. )
475
498
return quote
476
499
$ (Expr (:meta , :inline ))
485
508
return Dual {T} (result, _mul_partials (partials (x), partials (y), vy, vx))
486
509
end
487
510
488
- @generated function calc_fma_xz (x:: Dual{T,<:Real ,N} ,
511
+ @generated function calc_fma_xz (x:: Dual{T,<:Any ,N} ,
489
512
y:: Real ,
490
- z:: Dual{T,<:Real ,N} ) where {T,N}
513
+ z:: Dual{T,<:Any ,N} ) where {T,N}
491
514
ex = Expr (:tuple , [:(fma (partials (x)[$ i], y, partials (z)[$ i])) for i in 1 : N]. .. )
492
515
return quote
493
516
$ (Expr (:meta , :inline ))
510
533
# muladd #
511
534
# --------#
512
535
513
- @generated function calc_muladd_xyz (x:: Dual{T,<:Real ,N} ,
514
- y:: Dual{T,<:Real ,N} ,
515
- z:: Dual{T,<:Real ,N} ) where {T,N}
536
+ @generated function calc_muladd_xyz (x:: Dual{T,<:Any ,N} ,
537
+ y:: Dual{T,<:Any ,N} ,
538
+ z:: Dual{T,<:Any ,N} ) where {T,N}
516
539
ex = Expr (:tuple , [:(muladd (value (x), partials (y)[$ i], muladd (value (y), partials (x)[$ i], partials (z)[$ i]))) for i in 1 : N]. .. )
517
540
return quote
518
541
$ (Expr (:meta , :inline ))
527
550
return Dual {T} (result, _mul_partials (partials (x), partials (y), vy, vx))
528
551
end
529
552
530
- @generated function calc_muladd_xz (x:: Dual{T,<:Real ,N} ,
553
+ @generated function calc_muladd_xz (x:: Dual{T,<:Any ,N} ,
531
554
y:: Real ,
532
- z:: Dual{T,<:Real ,N} ) where {T,N}
555
+ z:: Dual{T,<:Any ,N} ) where {T,N}
533
556
ex = Expr (:tuple , [:(muladd (partials (x)[$ i], y, partials (z)[$ i])) for i in 1 : N]. .. )
534
557
return quote
535
558
$ (Expr (:meta , :inline ))
0 commit comments