Skip to content

Commit c779ae8

Browse files
authored
feat: replace enumerate (#397)
* feat: replace enumerate * fix: name iter and disable single itersym
1 parent d077140 commit c779ae8

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

src/constructors.jl

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ end
7878

7979

8080
function LoopSet(q::Expr, mod::Symbol = :Main)
81-
contract_pass!(q)
8281
ls = LoopSet(mod)
82+
check_inputs!(q, ls.prepreamble)
83+
contract_pass!(q)
8384
copyto!(ls, q)
8485
resize!(ls.loop_order, num_loops(ls))
8586
ls
@@ -157,23 +158,70 @@ function process_args(
157158
inline, check_empty, u₁, u₂, v, threads, warncheckarg
158159
end
159160
# check if the body of loop is a block, if not convert it to a block issue#395
160-
function check_loopbody!(q)
161-
if q isa Expr && q.head == :for
161+
# and check if the range of loop is an enumerate, if it is replace it, issue#393
162+
function check_inputs!(q, prepreamble)
163+
if Meta.isexpr(q, :for)
162164
if !Meta.isexpr(q.args[2], :block)
163165
q.args[2] = Expr(:block, q.args[2])
164-
else
166+
replace_enumerate!(q, prepreamble) # must after warp block
167+
else # maybe inner loops in block
168+
replace_enumerate!(q, prepreamble)
165169
for arg in q.args[2].args
166-
check_loopbody!(arg) # check recursively for inner loop
170+
check_inputs!(arg, prepreamble) # check recursively for inner loop
167171
end
168172
end
169173
end
170174
return q
171175
end
176+
function replace_enumerate!(q, prepreamble)
177+
looprange = q.args[1]
178+
if Meta.isexpr(looprange, :block)
179+
for i in 1:length(looprange.args)
180+
replace_single_enumerate!(q, prepreamble, i)
181+
end
182+
else
183+
replace_single_enumerate!(q, prepreamble)
184+
end
185+
return q
186+
end
187+
function replace_single_enumerate!(q, prepreamble, i=nothing)
188+
if isnothing(i) # not nest loop
189+
looprange, body = q.args[1], q.args[2]
190+
else # nest loop
191+
looprange, body = q.args[1].args[i], q.args[2]
192+
end
193+
@assert Meta.isexpr(looprange, :(=), 2)
194+
itersyms, r = looprange.args
195+
if Meta.isexpr(r, :call, 2) && r.args[1] == :enumerate
196+
_iter = r.args[2]
197+
if _iter isa Symbol
198+
iter = _iter
199+
else # name complex expr
200+
iter = gensym(:iter)
201+
push!(prepreamble.args, :($iter = $_iter))
202+
end
203+
if Meta.isexpr(itersyms, :tuple, 2)
204+
indsym, varsym = itersyms.args[1]::Symbol, itersyms.args[2]::Symbol
205+
_replace_looprange!(q, i, indsym, iter)
206+
pushfirst!(body.args, :($varsym = $iter[$indsym + firstindex($iter) - 1]))
207+
elseif Meta.isexpr(itersyms, :tuple, 1) # like `for (i,) in enumerate(...)`
208+
indsym = itersyms.args[1]::Symbol
209+
_replace_looprange!(q, i, indsym, iter)
210+
elseif itersyms isa Symbol # if itersyms are not unbox in loop range
211+
throw(ArgumentError("`for $itersyms in enumerate($r)` is not supported,
212+
please use `for ($(itersyms)_i, $(itersyms)_v) in enumerate($r)` instead."))
213+
else
214+
throw(ArgumentError("Don't know how to handle expression `$itersyms`."))
215+
end
216+
end
217+
return q
218+
end
219+
_replace_looprange!(q, ::Nothing, indsym, iter) = q.args[1] = :($indsym = Base.OneTo(length($iter)))
220+
_replace_looprange!(q, i::Int, indsym, iter) = q.args[1].args[i] = :($indsym = Base.OneTo(length($iter)))
172221

173222
function turbo_macro(mod, src, q, args...)
174223
q = macroexpand(mod, q)
175224
if q.head === :for
176-
check_loopbody!(q)
177225
ls = LoopSet(q, mod)
178226
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
179227
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))

test/parsing_inputs.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using LoopVectorization, Test, ArrayInterface
2+
using LoopVectorization: check_inputs!
23

34
# macros for generate loops whose body is not a block
45
macro gen_loop_issue395(ex)
@@ -48,3 +49,31 @@ end
4849
@test E == A * B'
4950
@test F == C * E
5051
end
52+
53+
@testset "enumerate, #393" begin
54+
A = zeros(4)
55+
B = zeros(4)
56+
C = zeros(4, 4)
57+
D = zeros(4, 4)
58+
@turbo for (i, x) in enumerate(1:4)
59+
A[i] = x
60+
end
61+
@turbo for (i,) in enumerate(B)
62+
B[i] += 1
63+
end
64+
@turbo for (j, Aj) in enumerate(A), (i, Bi) in enumerate(B)
65+
C[i, j] = Aj * Bi
66+
end
67+
@turbo for (j, Bj) in enumerate(B)
68+
for (i, Ai) in enumerate(A)
69+
D[i, j] = Ai * Bj
70+
end
71+
end
72+
@test A == 1:4
73+
@test B == ones(4)
74+
@test A .* B' == C' == D
75+
@test_throws ArgumentError check_inputs!(:(for ix in enumerate(A)
76+
A[ix[1]] = ix[1] + ix[2]
77+
end), Any[])
78+
@test_throws ArgumentError check_inputs!(:(for () in enumerate(A); end), Any[])
79+
end

0 commit comments

Comments
 (0)