Skip to content

Commit 177708e

Browse files
authored
Update throughput and latency cost calc for shuffling reductions (#488)
* Update throughput and latency cost calc for shuffling reductions Also, double throughput of apple-m* * remove extra `@show` * Update gemm test -- but unsure why it should have changed
1 parent 9eb6f63 commit 177708e

File tree

2 files changed

+41
-29
lines changed

2 files changed

+41
-29
lines changed

src/modeling/determinestrategy.jl

+40-28
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function cannot_shuffle(
7777
)
7878
))
7979
end
80+
const DOUBLE_THROUGHPUT = occursin("apple-m", LoopVectorization.get_cpu_name())
8081
function cost(
8182
ls::LoopSet,
8283
op::Operation,
@@ -104,6 +105,7 @@ function cost(
104105
# all(opp -> (isloopvalue(opp) | isconstant(opp)), parents(op))
105106
return 0.0, 0, 0.0
106107
end
108+
shuffle_rt = 0
107109
opisvectorized = isvectorized(op)
108110
srt, sl, srp =
109111
opisvectorized ? vector_cost(instr, Wshift, size_T) : scalar_cost(instr)
@@ -131,6 +133,7 @@ function cost(
131133
if isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
132134
srt += 0.5reg_size(ls) / cache_lnsze(ls)
133135
end
136+
shuffle_rt += shifter
134137
srt += shifter # shifter == number of shuffles
135138
sl += shifter
136139
end
@@ -150,7 +153,11 @@ function cost(
150153
sl *= 3
151154
end
152155
end
153-
srt, sl, Float64(srp + 1)
156+
if DOUBLE_THROUGHPUT
157+
srt *= 0.5
158+
shuffle_rt >>= 1
159+
end
160+
srt, sl, Float64(srp + 1), shuffle_rt
154161
end
155162

156163
# Base._return_type()
@@ -252,6 +259,11 @@ function depchain_cost!(
252259
rtᵢ, slᵢ = cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
253260
rt += rtᵢ
254261
sl += slᵢ
262+
elseif isload(op)
263+
_, _, _, shufflecost =
264+
cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
265+
rt += shufflecost
266+
sl += shufflecost
255267
end
256268
rt, sl
257269
end
@@ -357,7 +369,7 @@ function unroll_no_reductions(ls, order, vloopsym)
357369
# # (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
358370
# (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
359371
end
360-
function determine_unroll_factor(
372+
function rthroughput_latency(
361373
ls::LoopSet,
362374
order::Vector{Symbol},
363375
unrolled::Symbol,
@@ -390,17 +402,17 @@ function determine_unroll_factor(
390402
if isouterreduction(ls, op) -1 || unrolled reduceddependencies(op)
391403
latency = max(sl, latency)
392404
end
393-
# if unrolled ∈ loopdependencies(op)
394-
# compute_recip_throughput_u += rt
395-
# else
396405
compute_recip_throughput += rt
397-
# end
398406
elseif isload(op)
399-
load_recip_throughput +=
400-
first(cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T))
407+
lrt, _, _, shufflert =
408+
cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
409+
load_recip_throughput += lrt - shufflert
410+
# shufflert considered as part of depchain_cost!
401411
elseif isstore(op)
402-
store_recip_throughput +=
403-
first(cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T))
412+
srt, _, _, shufflert =
413+
cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
414+
store_recip_throughput += srt - shufflert
415+
compute_recip_throughput += shufflert
404416
end
405417
end
406418
recip_throughput =
@@ -447,7 +459,7 @@ function determine_unroll_factor(
447459
elseif iszero(num_reductions) # handle `BitArray` loops w/out reductions
448460
return 8 ÷ ls.vector_width, vloopsym
449461
else # handle `BitArray` loops with reductions
450-
rttemp, ltemp = determine_unroll_factor(ls, order, vloopsym, vloopsym)
462+
rttemp, ltemp = rthroughput_latency(ls, order, vloopsym, vloopsym)
451463
UF =
452464
min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp)))))
453465
UFfactor = 8 ÷ ls.vector_width
@@ -471,7 +483,7 @@ function determine_unroll_factor(
471483
best_unrolled = Symbol("")
472484
for unrolled order
473485
reject_reorder(ls, unrolled, false) && continue
474-
rttemp, ltemp = determine_unroll_factor(ls, order, unrolled, vloopsym)
486+
rttemp, ltemp = rthroughput_latency(ls, order, unrolled, vloopsym)
475487
rtcomptemp =
476488
rttemp + (
477489
0.01 *
@@ -1156,23 +1168,23 @@ end
11561168

11571169
update_cost_vec!(costs, cost, u₁reduces, u₂reduces) = @inbounds if u₁reduces &
11581170
u₂reduces
1159-
costs[4] += cost
1160-
elseif u₂reduces # cost decreased by unrolling u₂loop
1161-
costs[2] += cost
1162-
elseif u₁reduces # cost decreased by unrolling u₁loop
1163-
costs[3] += cost
1164-
else # no cost decrease; cost must be repeated
1165-
costs[1] += cost
1166-
end
1171+
costs[4] += cost
1172+
elseif u₂reduces # cost decreased by unrolling u₂loop
1173+
costs[2] += cost
1174+
elseif u₁reduces # cost decreased by unrolling u₁loop
1175+
costs[3] += cost
1176+
else # no cost decrease; cost must be repeated
1177+
costs[1] += cost
1178+
end
11671179
update_reg_pres!(rp, cost, u₁reduces, u₂reduces) = @inbounds if u₁reduces# & u₂reduces
1168-
rp[4] -= cost
1169-
elseif u₂reduces # cost decreased by unrolling u₂loop
1170-
rp[2] += cost
1171-
# elseif u₁reduces # cost decreased by unrolling u₁loop
1172-
# rp[4] -= cost
1173-
else # no cost decrease; cost must be repeated
1174-
rp[1] += cost
1175-
end
1180+
rp[4] -= cost
1181+
elseif u₂reduces # cost decreased by unrolling u₂loop
1182+
rp[2] += cost
1183+
# elseif u₁reduces # cost decreased by unrolling u₁loop
1184+
# rp[4] -= cost
1185+
else # no cost decrease; cost must be repeated
1186+
rp[1] += cost
1187+
end
11761188
function child_dependent_u₁u₂(op::Operation)
11771189
u₁ = u₂ = false
11781190
for opc children(op)

test/gemm.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@
412412
elseif LoopVectorization.register_count() == 16
413413
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6)
414414
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 2, 4)
415-
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 3, 3)
415+
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 3, 3)
416416
end
417417
function rank2AmulBavx!(C, Aₘ, Aₖ, B)
418418
@turbo for m axes(C, 1), n axes(C, 2)

0 commit comments

Comments
 (0)