Skip to content

Commit ea2d6f0

Browse files
committed
Add cse-ed ops to opdict, unroll some short static vectors.
1 parent 133bd96 commit ea2d6f0

File tree

4 files changed

+68
-8
lines changed

4 files changed

+68
-8
lines changed

src/graphs.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,10 @@ end
322322
operations(ls::LoopSet) = ls.operations
323323
function pushop!(ls::LoopSet, op::Operation, var::Symbol = name(op))
324324
for opp ∈ operations(ls)
325-
matches(op, opp) && return opp
325+
if matches(op, opp)
326+
ls.opdict[var] = opp
327+
return opp
328+
end
326329
end
327330
push!(ls.operations, op)
328331
ls.opdict[var] = op

src/lower_compute.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# A compute op needs to know the unrolling and tiling status of each of its parents.
22

3+
function promote_to_231(op, unrolled, tiled)
4+
unrolleddeps = Symbol[]
5+
unrolled ∈ loopdependencies(op) && push!(unrolleddeps, unrolled)
6+
tiled ∈ loopdependencies(op) && push!(unrolleddeps, tiled)
7+
!any(opp -> isload(opp) && all(in(loopdependencies(opp)), unrolleddeps), parents(op))
8+
end
9+
310
struct FalseCollection end
411
Base.getindex(::FalseCollection, i...) = false
512
function lower_compute!(
@@ -76,7 +83,9 @@ function lower_compute!(
7683
if !isnothing(suffix) && isreduct
7784
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
7885
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
79-
if instrfid !== nothing && !any(opp -> isload(opp) && all(in(loopdependencies(opp)), loopdependencies(op)), parents(op)) # want to instcombine when parent load's deps are superset
86+
# want to instcombine when parent load's deps are superset
87+
# also make sure opp is unrolled
88+
if instrfid !== nothing && (opunrolled && U > 1) && promote_to_231(op, unrolled, tiled)
8089
instr = Instruction((:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)[instrfid])
8190
end
8291
end

src/lowering.jl

+30-6
Original file line numberDiff line numberDiff line change
@@ -172,27 +172,49 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
172172
nisvectorized = isvectorized(us, n)
173173
loopisstatic = isstaticloop(loop) & (!nisvectorized)
174174

175+
175176
remmask = inclmask | nisvectorized
176177
Ureduct = (n == num_loops(ls) && (u₂ == -1)) ? calc_Ureduct(ls, us) : -1
177178
sl = startloop(loop, nisvectorized, ls.W, loopsym)
178-
tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF)
179-
body = lower_block(ls, us, n, inclmask, UF)
180179

180+
remfirst = loopisstatic & !(unsigned(Ureduct) < unsigned(UF))
181+
if remfirst
182+
tc = Expr(:call, lv(:scalar_less), loopsym, loop.stophint + 1)
183+
else
184+
tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF)
185+
end
186+
body = lower_block(ls, us, n, inclmask, UF)
181187
q = Expr(:while, tc, body)
188+
remblock = init_remblock(loop, loopsym)
189+
UFt = if loopisstatic
190+
length(loop) % UF
191+
else
192+
1
193+
end
182194
q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt); is logic relying on twos-complement bad?
183195
Expr(
184196
:block, sl,
185197
add_upper_outer_reductions(ls, q, Ureduct, UF, loop, vectorized),
186198
Expr(
187199
:if, terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF - Ureduct),
188200
lower_block(ls, us, n, inclmask, UF - Ureduct)
189-
)
201+
),
202+
remblock
190203
)
204+
elseif remfirst
205+
numiters = length(loop) ÷ UF
206+
if numiters > 2
207+
Expr( :block, sl, remblock, q )
208+
else
209+
q = Expr(:block, sl, remblock)
210+
for i ∈ 1:numiters
211+
push!(q.args, body)
212+
end
213+
q
214+
end
191215
else
192-
Expr( :block, sl, q )
216+
Expr( :block, sl, q, remblock )
193217
end
194-
remblock = init_remblock(loop, loopsym)
195-
push!(q.args, remblock)
196218
UFt = if loopisstatic
197219
length(loop) % UF
198220
else
@@ -206,6 +228,8 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
206228
Expr(:call, :-, loop.stopsym, Expr(:call, lv(:valmul), ls.W, UFt))
207229
end
208230
Expr(:call, lv(:scalar_greater), loopsym, itercount)
231+
elseif remfirst
232+
Expr(:call, lv(:scalar_less), loopsym, loop.starthint + UFt)
209233
elseif loop.stopexact
210234
Expr(:call, lv(:scalar_greater), loopsym, loop.stophint - UFt)
211235
else

test/special.jl

+24
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,26 @@
251251
end; y
252252
end
253253

254+
function csetanh!(y, z, x)
255+
for j in axes(x, 2)
256+
for i = axes(x, 1)
257+
t2 = inv(tanh(x[i, j]))
258+
t1 = tanh(x[i, j])
259+
y[i, j] = z[i, j] * (-(1 - t1 ^ 2) * t2)
260+
end
261+
end
262+
y
263+
end
264+
function csetanhavx!(y, z, x)
265+
@avx for j in axes(x, 2)
266+
for i = axes(x, 1)
267+
t2 = inv(tanh(x[i, j]))
268+
t1 = tanh(x[i, j])
269+
y[i, j] = z[i, j] * (-(1 - t1 ^ 2) * t2)
270+
end
271+
end
272+
y
273+
end
254274

255275
for T ∈ (Float32, Float64)
256276
@show T, @__LINE__
@@ -314,5 +334,9 @@
314334
@test vpowf!(r1, x, -1.7) ≈ (r2 .= x .^ -1.7)
315335
p = randn(length(x));
316336
@test vpowf!(r1, x, x) ≈ (r2 .= x .^ x)
337+
338+
X = rand(T, N, M); Z = rand(T, N, M);
339+
Y1 = similar(X); Y2 = similar(Y1);
340+
@test csetanh!(Y1, X, Z) ≈ csetanhavx!(Y2, X, Z)
317341
end
318342
end

0 commit comments

Comments
 (0)