11module ForwardDiffExt
22import ForwardDiff, ChainRulesCore
3- using SIMDDualNumbers, LoopVectorization
3+ using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff
4+
5+ import IfElse: ifelse
6+ using VectorizationBase: AbstractSIMD, AbstractMask, zero_offsets
7+
48using LoopVectorization:
59 AbstractSIMD,
610 AbstractStridedPointer,
@@ -18,7 +22,188 @@ using LoopVectorization:
1822 mask,
1923 vfnmadd_fast,
2024 mul_fast
21- using VectorizationBase: zero_offsets
25+
26+ @generated function Base. abs (
27+ x:: ForwardDiff.Dual{TAG,S,N}
28+ ) where {TAG,S<: AbstractSIMD ,N}
29+ quote
30+ $ (Expr (:meta , :inline ))
31+ val = x. value
32+ p = x. partials
33+ cmp = val < zero ($ S)
34+ absx = $ ifelse (cmp, - val, val)
35+ Base. Cartesian. @nexprs $ N n -> p_n = p[n]
36+ ForwardDiff. Dual {$TAG} (
37+ absx,
38+ ForwardDiff. Partials (
39+ Base. Cartesian. @ntuple $ N n -> $ ifelse (cmp, - p_n, p_n)
40+ )
41+ )
42+ end
43+ end
44+ @inline function Base. max (
45+ x:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N} ,
46+ y:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
47+ ) where {TAG,N}
48+ vx = ForwardDiff. value (x)
49+ vy = ForwardDiff. value (y)
50+ xgy = vx > vy
51+ z = ifelse (xgy, vx, vy)
52+ p = VectorizationBase. fmap (
53+ ifelse,
54+ xgy,
55+ ForwardDiff. partials (x). values,
56+ ForwardDiff. partials (y). values
57+ )
58+ ForwardDiff. Dual {TAG} (z, ForwardDiff. Partials (p))
59+ end
60+
61+ @inline Base. max (
62+ x:: T ,
63+ y:: Real
64+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
65+ @inline Base. max (
66+ y:: Real ,
67+ x:: T
68+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
69+ @inline Base. max (
70+ x:: T ,
71+ y:: Int
72+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
73+ @inline Base. max (
74+ y:: Int ,
75+ x:: T
76+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = max (x, T (y))
77+
78+ @inline function Base. min (
79+ x:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N} ,
80+ y:: ForwardDiff.Dual{TAG,<:AbstractSIMD,N}
81+ ) where {TAG,N}
82+ vx = ForwardDiff. value (x)
83+ vy = ForwardDiff. value (y)
84+ xgy = vx < vy
85+ z = ifelse (xgy, vx, vy)
86+ p = VectorizationBase. fmap (
87+ ifelse,
88+ xgy,
89+ ForwardDiff. partials (x). values,
90+ ForwardDiff. partials (y). values
91+ )
92+ ForwardDiff. Dual {TAG} (z, ForwardDiff. Partials (p))
93+ end
94+ @inline Base. min (
95+ x:: T ,
96+ y:: Real
97+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
98+ @inline Base. min (
99+ y:: Real ,
100+ x:: T
101+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
102+ @inline Base. min (
103+ x:: T ,
104+ y:: Int
105+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
106+ @inline Base. min (
107+ y:: Int ,
108+ x:: T
109+ ) where {N,T<: ForwardDiff.Dual{<:Any,<:AbstractSIMD,N} } = min (x, T (y))
110+
111+ @generated function SLEEFPirates. tanh_fast (
112+ x:: ForwardDiff.Dual{T,S,N}
113+ ) where {T,S,N}
114+ quote
115+ $ (Expr (:meta , :inline ))
116+ t = tanh_fast (x. value)
117+ ∂t = $ (VectorizationBase. vfnmadd_fast)(t, t, one (S))
118+ p = x. partials
119+ ForwardDiff. Dual {T} (
120+ t,
121+ ForwardDiff. Partials (
122+ Base. Cartesian. @ntuple $ N n -> $ (Base. FastMath. mul_fast)(∂t, p[n])
123+ )
124+ )
125+ end
126+ end
127+ @generated function SLEEFPirates. sigmoid_fast (
128+ x:: ForwardDiff.Dual{T,S,N}
129+ ) where {T,S,N}
130+ quote
131+ $ (Expr (:meta , :inline ))
132+ s = sigmoid_fast (x. value)
133+ ∂s = $ (VectorizationBase. vfnmadd_fast)(s, s, s)
134+ p = x. partials
135+ ForwardDiff. Dual {T} (
136+ s,
137+ ForwardDiff. Partials (
138+ Base. Cartesian. @ntuple $ N n -> $ (Base. FastMath. mul_fast)(∂s, p[n])
139+ )
140+ )
141+ end
142+ end
143+ @generated function VectorizationBase. relu (
144+ x:: ForwardDiff.Dual{T,S,N}
145+ ) where {T,S,N}
146+ quote
147+ $ (Expr (:meta , :inline ))
148+ v = x. value
149+ z = zero (v)
150+ cmp = v < z
151+ r = ifelse (cmp, z, v)
152+ p = x. partials
153+ ForwardDiff. Dual {T} (
154+ r,
155+ ForwardDiff. Partials (Base. Cartesian. @ntuple $ N n -> ifelse (cmp, z, p[n]))
156+ )
157+ end
158+ end
159+
160+ @generated function ifelse (
161+ m:: AbstractMask ,
162+ x:: ForwardDiff.Dual{TAG,V,P} ,
163+ y:: ForwardDiff.Dual{TAG,V,P}
164+ ) where {TAG,V,P}
165+ quote
166+ $ (Expr (:meta , :inline ))
167+ z = $ ifelse (m, ForwardDiff. value (x), ForwardDiff. value (y))
168+ px = ForwardDiff. partials (x)
169+ py = ForwardDiff. partials (y)
170+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, px[p], py[p])
171+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
172+ end
173+ end
174+ @generated function ifelse (
175+ m:: AbstractMask ,
176+ x:: Number ,
177+ y:: ForwardDiff.Dual{TAG,V,P}
178+ ) where {TAG,V,P}
179+ quote
180+ $ (Expr (:meta , :inline ))
181+ z = $ ifelse (m, x, ForwardDiff. value (y))
182+ py = ForwardDiff. partials (y)
183+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, zero ($ V), py[p])
184+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
185+ end
186+ end
187+ @generated function ifelse (
188+ m:: AbstractMask ,
189+ x:: ForwardDiff.Dual{TAG,V,P} ,
190+ y:: Number
191+ ) where {TAG,V,P}
192+ quote
193+ $ (Expr (:meta , :inline ))
194+ z = $ ifelse (m, ForwardDiff. value (x), y)
195+ px = ForwardDiff. partials (x)
196+ p = Base. Cartesian. @ntuple $ P p -> $ ifelse (m, px[p], zero ($ V))
197+ ForwardDiff. Dual {$TAG} (z, ForwardDiff. Partials (p))
198+ end
199+ end
200+ @inline function SLEEFPirates. softplus (x:: ForwardDiff.Dual{TAG} ) where {TAG}
201+ val = ForwardDiff. value (x)
202+ expx = exp (val)
203+ vx = log1p (expx)
204+ px = Base. FastMath. inv_fast (one (val) + Base. FastMath. inv_fast (expx))
205+ ForwardDiff. Dual {TAG} (vx, Base. FastMath. mul_fast (ForwardDiff. partials (x), px))
206+ end
22207
23208@generated function init_dual (v:: Tuple{Vararg{AbstractSIMD,A}} ) where {A}
24209 res = Expr (:tuple )
0 commit comments