Skip to content

Commit e123cb2

Browse files
Check operations in @turbo automatically with can_avx; if failure, switch to @inbounds @fastmath (#431)
* Create `safe` kwarg for `@turbo` macro

  Currently, this macro does nothing.
* Run `can_avx` on each operator when checking loopset
* Refactor `can_avx` test
* Add test for `safe=true` option in `@turbo`
* Remove debugging statement
* Clean up preamble generation
* Set `safe=false` for `@turbo` by default
* Switch to more generic `can_turbo` function for safe `@turbo`
* Remove `@turbo safe=true` tests from `can_avx.jl`
* Create file to test `@turbo safe=true` and `can_turbo`
* Compute `nargs` of instruction properly
* Add missing `safe` kwarg in `vmaterialize!`
* Also unpack `warncheckarg` and `safe` from UNROLL
* Ensure warncheckarg and safe passed everywhere for consistency
* Consistency in `UNROLL` name
* Add packages required for testing to `[extras]` and `[targets]`
* Add `safe` and `warncheckarg` throughout library
* Remove edits to Project
* Add missing imports in safe `@turbo` tests
* Fix call to `can_avx`
* Remove nested `testset`

  Seems to be breaking imports.
* Test that `can_avx` validates `exp` by itself
* Add SpecialFunctions.jl to test
* Clean up test set
* Ping test
* Ensure that function names in safe test are unique
* Add `RetVec2Int` for julia <1.6 as `Returns()`

  Co-authored-by: Chris Elrod <elrodc@gmail.com>
* Use `RetVec2Int()` instead of `Returns(Vec{2,Int})`

  Co-authored-by: Chris Elrod <elrodc@gmail.com>
* push functions into prepre

  Co-authored-by: Chris Elrod <elrodc@gmail.com>
1 parent 1238fc8 commit e123cb2

11 files changed

+164
-43
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <elrodc@gmail.com>"]
4-
version = "0.12.128"
4+
version = "0.12.129"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broadcast.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ end
548548
# we have an N dimensional loop.
549549
# need to construct the LoopSet
550550
ls = LoopSet(Mod)
551-
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
551+
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
552552
set_hw!(ls, rs, rc, cls)
553553
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
554554
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
@@ -571,6 +571,7 @@ end
571571
v,
572572
threads % Int,
573573
warncheckarg,
574+
safe,
574575
)
575576
Expr(:block, Expr(:meta, :inline), sc, :dest)
576577
end
@@ -584,7 +585,7 @@ end
584585
# we have an N dimensional loop.
585586
# need to construct the LoopSet
586587
ls = LoopSet(Mod)
587-
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
588+
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
588589
set_hw!(ls, rs, rc, cls)
589590
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
590591
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
@@ -614,6 +615,7 @@ end
614615
v,
615616
threads % Int,
616617
warncheckarg,
618+
safe,
617619
),
618620
:dest′,
619621
)
@@ -626,7 +628,7 @@ end
626628
::Val{UNROLL},
627629
::Val{dontbc}
628630
) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc}
629-
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
631+
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
630632
quote
631633
$(Expr(:meta, :inline))
632634
arg = T(first(bc.args))
@@ -646,7 +648,7 @@ end
646648
::Val{UNROLL},
647649
::Val{dontbc}
648650
) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc}
649-
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
651+
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
650652
quote
651653
$(Expr(:meta, :inline))
652654
arg = T(first(bc.args))
@@ -660,8 +662,8 @@ end
660662
dest′
661663
end
662664
end
663-
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{Unroll}) where {Mod,Unroll}
664-
vmaterialize!(dest, bc, Val{Mod}(), Val{Unroll}(), Val(_dontbc(bc)))
665+
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{UNROLL}) where {Mod,UNROLL}
666+
vmaterialize!(dest, bc, Val{Mod}(), Val{UNROLL}(), Val(_dontbc(bc)))
665667
end
666668

667669
@inline function vmaterialize(

src/codegen/lower_threads.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ function thread_one_loops_expr(
420420
valid_thread_loop::Vector{Bool},
421421
ntmax::UInt,
422422
c::Float64,
423-
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
423+
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
424424
OPS::Expr,
425425
ARF::Expr,
426426
AM::Expr,
@@ -615,7 +615,7 @@ function thread_two_loops_expr(
615615
valid_thread_loop::Vector{Bool},
616616
ntmax::UInt,
617617
c::Float64,
618-
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
618+
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
619619
OPS::Expr,
620620
ARF::Expr,
621621
AM::Expr,
@@ -877,7 +877,7 @@ function valid_thread_loops(ls::LoopSet)
877877
end
878878
function avx_threads_expr(
879879
ls::LoopSet,
880-
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
880+
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
881881
nt::UInt,
882882
OPS::Expr,
883883
ARF::Expr,

src/condense_loopset.jl

+63-6
Original file line numberDiff line numberDiff line change
@@ -558,9 +558,9 @@ end
558558
::StaticInt{NT},
559559
::StaticInt{CLS},
560560
) where {CNFARG,W,RS,AR,CLS,NT}
561-
inline, u₁, u₂, v, BROADCAST, thread = CNFARG
561+
inline, u₁, u₂, v, BROADCAST, thread, warncheckarg, safe = CNFARG
562562
nt = min(thread % UInt, NT % UInt)
563-
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt)
563+
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt, warncheckarg, safe)
564564
length(CNFARG) == 7 && push!(t.args, CNFARG[7])
565565
Expr(:call, Expr(:curly, :Val, t))
566566
end
@@ -605,6 +605,8 @@ function split_ifelse!(
605605
k::Int,
606606
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
607607
thread::UInt,
608+
warncheckarg::Int,
609+
safe::Bool,
608610
debug::Bool,
609611
)
610612
roots[k] = false
@@ -662,6 +664,8 @@ function split_ifelse!(
662664
copy(extra_args),
663665
inlineu₁u₂,
664666
thread,
667+
warncheckarg,
668+
safe,
665669
debug,
666670
))
667671
else
@@ -673,6 +677,8 @@ function split_ifelse!(
673677
extra_args,
674678
inlineu₁u₂,
675679
thread,
680+
warncheckarg,
681+
safe,
676682
debug,
677683
))
678684
end
@@ -685,6 +691,8 @@ function generate_call(
685691
ls::LoopSet,
686692
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
687693
thread::UInt,
694+
warncheckarg::Int,
695+
safe::Bool,
688696
debug::Bool,
689697
)
690698
extra_args = Expr(:tuple)
@@ -698,6 +706,8 @@ function generate_call(
698706
extra_args,
699707
inlineu₁u₂,
700708
thread,
709+
warncheckarg,
710+
safe,
701711
debug,
702712
)
703713
end
@@ -709,6 +719,8 @@ function generate_call_split(
709719
extra_args::Expr,
710720
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
711721
thread::UInt,
722+
warncheckarg::Int,
723+
safe::Bool,
712724
debug::Bool,
713725
)
714726
for (k, op) ∈ enumerate(operations(ls))
@@ -725,6 +737,8 @@ function generate_call_split(
725737
k,
726738
inlineu₁u₂,
727739
thread,
740+
warncheckarg,
741+
safe,
728742
debug,
729743
)
730744
end
@@ -737,6 +751,8 @@ function generate_call_split(
737751
extra_args,
738752
inlineu₁u₂,
739753
thread,
754+
warncheckarg,
755+
safe,
740756
debug,
741757
)
742758
end
@@ -750,6 +766,8 @@ function generate_call_types(
750766
extra_args::Expr,
751767
(inline, u₁, u₂, v)::Tuple{Bool,Int8,Int8,Int8},
752768
thread::UInt,
769+
warncheckarg::Int,
770+
safe::Bool,
753771
debug::Bool,
754772
)
755773
# good place to check for split
@@ -782,7 +800,7 @@ function generate_call_types(
782800
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
783801
func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!)
784802
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
785-
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread)
803+
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe)
786804
unroll_param_tup =
787805
Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
788806
q = Expr(
@@ -884,6 +902,39 @@ function check_args_call(ls::LoopSet)
884902
end
885903
q
886904
end
905+
struct RetVec2Int end
906+
(::RetVec2Int)(_) = Vec{2,Int}
907+
"""
908+
can_turbo(f::Function, ::Val{NARGS})
909+
910+
Check whether a given function with a specified number of arguments
911+
can be used inside a `@turbo` loop.
912+
"""
913+
function can_turbo(f::F, ::Val{NARGS})::Bool where {F,NARGS}
914+
promoted_op = Base.promote_op(f, ntuple(RetVec2Int(), Val(NARGS))...)
915+
return promoted_op !== Union{}
916+
end
917+
918+
"""
919+
check_turbo_safe(ls::LoopSet)
920+
921+
Returns an expression of the form `true && can_turbo(op1) && can_turbo(op2) && ...`
922+
"""
923+
function check_turbo_safe(ls::LoopSet)
924+
q = Expr(:&&, true)
925+
last = q
926+
for op in operations(ls)
927+
iscompute(op) || continue
928+
c = callexpr(op.instruction)
929+
nargs = length(parents(op))
930+
push!(c.args, Val(nargs))
931+
pushfirst!(c.args, can_turbo)
932+
new_last = Expr(:&&, c)
933+
push!(last.args, new_last)
934+
last = new_last
935+
end
936+
q
937+
end
887938

888939
make_fast(q) =
889940
Expr(:macrocall, Symbol("@fastmath"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), q)
@@ -956,7 +1007,7 @@ function setup_call_final(ls::LoopSet, q::Expr)
9561007
return ls.preamble
9571008
end
9581009
function setup_call_debug(ls::LoopSet)
959-
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), true)
1010+
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), 1, true, true)
9601011
end
9611012
function setup_call(
9621013
ls::LoopSet,
@@ -969,6 +1020,7 @@ function setup_call(
9691020
v::Int8,
9701021
thread::Int,
9711022
warncheckarg::Int,
1023+
safe::Bool,
9721024
)
9731025
# We outline/inline at the macro level by creating/not creating an anonymous function.
9741026
# The old API instead was based on inlining or not inlining the generated function, but
@@ -977,7 +1029,7 @@ function setup_call(
9771029
# inlining the generated function into the loop preamble.
9781030
lnns = extract_all_lnns(q)
9791031
pushfirst!(lnns, source)
980-
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, false)
1032+
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, 1, true, false)
9811033
call = check_empty ? check_if_empty(ls, call) : call
9821034
argfailure = make_crashy(make_fast(q))
9831035
if warncheckarg ≠ 0
@@ -986,7 +1038,12 @@ function setup_call(
9861038
warncheckarg > 0 && push!(warning.args, :(maxlog = $warncheckarg))
9871039
argfailure = Expr(:block, warning, argfailure)
9881040
end
989-
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
1041+
call_check = if safe
1042+
Expr(:&&, check_args_call(ls), check_turbo_safe(ls))
1043+
else
1044+
check_args_call(ls)
1045+
end
1046+
pushprepreamble!(ls, Expr(:if, call_check, call, argfailure))
9901047
prepend_lnns!(ls.prepreamble, lnns)
9911048
return ls.prepreamble
9921049
end

src/constructors.jl

+15-10
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,13 @@ function substitute_broadcast(
5252
v::Int8,
5353
threads::Int,
5454
warncheckarg::Int,
55+
safe::Bool,
5556
)
5657
ci = first(Meta.lower(LoopVectorization, q).args).code
5758
nargs = length(ci) - 1
5859
ex = Expr(:block)
5960
syms = [gensym() for _ ∈ 1:nargs]
60-
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg)
61+
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg, safe)
6162
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
6263
for n ∈ 1:nargs
6364
ciₙ = ci[n]
@@ -102,6 +103,7 @@ function check_macro_kwarg(
102103
v::Int8,
103104
threads::Int,
104105
warncheckarg::Int,
106+
safe::Bool,
105107
)
106108
((arg.head === :(=)) && (length(arg.args) == 2)) ||
107109
throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
@@ -132,14 +134,16 @@ function check_macro_kwarg(
132134
end
133135
elseif kw === :warn_check_args
134136
warncheckarg = convert(Int, value)::Int
137+
elseif kw === :safe
138+
safe = convert(Bool, value)
135139
else
136140
throw(
137141
ArgumentError(
138142
"Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`.",
139143
),
140144
)
141145
end
142-
inline, check_empty, u₁, u₂, v, threads, warncheckarg
146+
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
143147
end
144148
function process_args(
145149
args;
@@ -150,12 +154,13 @@ function process_args(
150154
v::Int8 = zero(Int8),
151155
threads::Int = 1,
152156
warncheckarg::Int = 1,
157+
safe::Bool = false,
153158
)
154159
for arg ∈ args
155-
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
156-
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg)
160+
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
161+
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe)
157162
end
158-
inline, check_empty, u₁, u₂, v, threads, warncheckarg
163+
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
159164
end
160165
# check if the body of loop is a block, if not convert it to a block issue#395
161166
# and check if the range of loop is an enumerate, if it is replace it, issue#393
@@ -225,12 +230,12 @@ function turbo_macro(mod, src, q, args...)
225230
q = macroexpand(mod, q)
226231
if q.head === :for
227232
ls = LoopSet(q, mod)
228-
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
229-
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))
233+
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = process_args(args)
234+
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe))
230235
else
231-
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
236+
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
232237
process_args(args, inline = true)
233-
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg)
238+
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg, safe)
234239
end
235240
end
236241
"""
@@ -367,7 +372,7 @@ macro _turbo(arg, q)
367372
@assert q.head === :for
368373
q = macroexpand(__module__, q)
369374
inline, check_empty, u₁, u₂, v =
370-
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0)
375+
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0, true)
371376
ls = LoopSet(q, __module__)
372377
set_hw!(ls)
373378
def_outer_reduct_types!(ls)

src/modeling/graphs.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@ function instruction!(ls::LoopSet, x::Expr)
12831283
instr ∈ keys(COST) && return Instruction(:LoopVectorization, instr)
12841284
# end
12851285
instr = gensym!(ls, "f")
1286-
pushpreamble!(ls, Expr(:(=), instr, x))
1286+
pushprepreamble!(ls, Expr(:(=), instr, x))
12871287
Instruction(Symbol(""), instr)
12881288
end
12891289
instruction!(ls::LoopSet, x::Symbol) = instruction(x)
@@ -1481,7 +1481,7 @@ function add_operation!(
14811481
add_comparison!(ls, LHS_sym, RHS, elementbytes, position)
14821482
else
14831483
throw(LoopError("Expression not recognized.", RHS))
1484-
end
1484+
end
14851485
end
14861486

14871487
function prepare_rhs_for_storage!(

0 commit comments

Comments
 (0)