@@ -154,22 +154,14 @@ end
154
154
@inline choose_num_blocks (nt, :: StaticInt{NC} = lv_max_num_threads ()) where {NC} =
155
155
@inbounds choose_num_block_table (StaticInt {NC} ())[nt]
156
156
157
- if Sys. ARCH === :x86_64
158
- @inline function choose_num_threads (
157
+ scale_cost (c) = @fastmath c * (Sys. ARCH === :x86_64 ? 0.0225 : 0.005625 )
158
+ scale_cost (c, looplen) = scale_cost (@fastmath c / looplen)
159
+ @inline function choose_num_threads (
159
160
C:: T ,
160
161
NT:: UInt ,
161
162
x:: Base.BitInteger ,
162
163
) where {T<: Union{Float32,Float64} }
163
- _choose_num_threads (Base. mul_float_fast (T (C), T (0.0225 )), NT, x)
164
- end
165
- else
166
- @inline function choose_num_threads (
167
- C:: T ,
168
- NT:: UInt ,
169
- x:: Base.BitInteger ,
170
- ) where {T<: Union{Float32,Float64} }
171
- _choose_num_threads (Base. mul_float_fast (C, T (0.0225 ) * T (0.25 )), NT, x)
172
- end
164
+ _choose_num_threads (scale_cost (T (C)), NT, x)
173
165
end
174
166
@inline function _choose_num_threads (
175
167
C:: T ,
@@ -422,13 +414,6 @@ function define_block_size(threadedloop, vloop, tn, W)
422
414
end
423
415
end
424
416
end
425
- function scale_cost (c, looplen)
426
- c = 0.05 * c / looplen
427
- if Sys. ARCH != = :x86_64
428
- c *= 0.25
429
- end
430
- c
431
- end
432
417
function thread_one_loops_expr (
433
418
ls:: LoopSet ,
434
419
ua:: UnrollArgs ,
@@ -868,17 +853,20 @@ function valid_thread_loops(ls::LoopSet)
868
853
u₂loop = _u₂loop === nothing ? u₁loop : getloop_from_id (ls, _u₂loop)
869
854
ua = UnrollArgs (u₁loop, u₂loop, getloop (ls, vectorized), u₁, u₂, u₂)
870
855
valid_thread_loop = fill (true , length (order))
856
+ has_reduced_deps = false
871
857
for op ∈ operations (ls)
872
858
if isstore (op) && (length (reduceddependencies (op)) > 0 )
873
859
for reduceddep ∈ reduceddependencies (op)
874
860
for (i, o) ∈ enumerate (order)
875
861
if o === reduceddep
862
+ has_reduced_deps = true
876
863
valid_thread_loop[i] = false
877
864
end
878
865
end
879
866
end
880
867
end
881
868
end
869
+ c *= (1.0 + 0.5 has_reduced_deps)
882
870
for (i, o) ∈ enumerate (order)
883
871
loop = getloop (ls, o)
884
872
if isstaticloop (loop) & (length (loop) ≤ 1 )
0 commit comments