Skip to content

Commit 8679bf4

Browse files
committed
Fixed bug where Amn = A[m,n] caused an error when m and n are constants.
1 parent 30a9ddf commit 8679bf4

9 files changed

+45
-25
lines changed

src/add_compute.jl

+8
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,14 @@ function add_parent!(
5151
vparents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int, position::Int
5252
)
5353
parent = if var isa Symbol
54+
# if var === :kern_1_1
55+
# @show operations(ls) ls.preamble_symsym
56+
# end
5457
opp = getop(ls, var, elementbytes)
58+
# if var === :kern_1_1
59+
# @show operations(ls) ls.preamble_symsym
60+
# end
61+
# @show var opp first(operations(ls)) opp === first(operations(ls))
5562
if iscompute(opp) && instruction(opp).instr === :identity && length(loopdependencies(opp)) < position && isone(length(parents(opp))) && name(opp) === name(first(parents(opp)))
5663
first(parents(opp))
5764
else
@@ -220,6 +227,7 @@ function add_compute!(
220227
deps = Symbol[]
221228
reduceddeps = Symbol[]
222229
reduction_ind = 0
230+
# @show ex first(operations(ls)) === getop(ls, :kern_1_1, elementbytes) first(operations(ls)) getop(ls, :kern_1_1, elementbytes)
223231
for (ind,arg) ∈ enumerate(args)
224232
if var === arg
225233
reduction_ind = ind

src/add_constants.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int)
22
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
3-
pushpreamble!(ls, op, var)
4-
pushop!(ls, op, var)
3+
rop = pushop!(ls, op, var)
4+
rop === op && pushpreamble!(ls, op, var)
5+
rop
56
end
67
# function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
78
# sym = gensym(:loopconstant)
@@ -12,6 +13,8 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
1213
op = Operation(length(operations(ls)), gensym(:loopconstnumber), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
1314
ops = operations(ls)
1415
typ = var isa Integer ? HardInt : HardFloat
16+
rop = pushop!(ls, op)
17+
rop !== op && return rop
1518
if iszero(var)
1619
for (id,typ_) ∈ ls.preamble_zeros
1720
(instruction(ops[id]) === LOOPCONSTANT && typ == typ_) && return ops[id]
@@ -28,7 +31,7 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
2831
end
2932
push!(ls.preamble_symfloat, (identifier(op), var))
3033
end
31-
pushop!(ls, op)
34+
rop
3235
end
3336
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
3437
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)

src/add_loads.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
function add_load!(
2424
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int
2525
)
26-
length(mpref.loopdependencies) == 0 && return add_constant!(ls, mpref, elementbytes)
26+
iszero(length(mpref.loopdependencies)) && return add_constant!(ls, mpref, elementbytes)
2727
op = Operation( ls, varname(mpref), elementbytes, :getindex, memload, mpref )
2828
add_load!(ls, op, true, false)
2929
end

src/graphs.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,9 @@ function add_loop!(ls::LoopSet, loop::Loop, itersym::Symbol = loop.itersymbol)
522522
nothing
523523
end
524524

525-
function instruction(x)
526-
x isa Symbol ? x : last(x.args).value
527-
end
525+
# function instruction(x)
526+
# x isa Symbol ? x : last(x.args).value
527+
# end
528528
# instruction(ls::LoopSet, f::Symbol) = instruction!(ls, f)
529529
function instruction!(ls::LoopSet, x::Expr)
530530
x isa Symbol && return x

src/lower_compute.jl

+4-11
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11

2+
23
function load_constrained(op, u₁loop, u₂loop, innermost_loop, forprefetch = false)
34
loopdeps = loopdependencies(op)
45
dependsonu₁ = u₁loop ∈ loopdeps
5-
if u₂loop === Symbol("##undefined##")
6-
if forprefetch
7-
dependsonu₁ || return false
8-
end
9-
# unrolleddeps = [ u₁loop ]
10-
else
11-
dependsonu₂ = u₂loop ∈ loopdeps
12-
if forprefetch
13-
(dependsonu₁ & dependsonu₂) || return false
14-
end
15-
# unrolleddeps = [ u₁loop, u₂loop ]
6+
dependsonu₂ = u₂loop ∈ loopdeps
7+
if forprefetch
8+
(dependsonu₁ & dependsonu₂) || return false
169
end
1710
unrolleddeps = Symbol[]
1811
dependsonu₁ && push!(unrolleddeps, u₁loop)

src/operations.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function matches(op1::Operation, op2::Operation)
206206
op1.instruction === op2.instruction || return false
207207
op1.node_type == op2.node_type || return false
208208
if isconstant(op1)
209-
return false
209+
return iszero(length(loopdependencies(op1))) && iszero(length(loopdependencies(op2))) && (mangledvar(op1) === mangledvar(op2))
210210
end
211211
op1.dependencies == op2.dependencies || return false
212212
op2.reduced_deps == op2.reduced_deps || return false

src/reconstruct_loopset.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ function add_op!(
306306
# If it's a CartesianIndex add or subtract, we may have to add multiple operations
307307
expanded = expandedv[i]# isexpanded(ls, ops, nopsv, i)
308308
opoffsets = ls.operation_offsets
309-
offsets = ls.loopsymbol_offsets
309+
# offsets = ls.loopsymbol_offsets
310310
optyp = optype(os)
311311
if !expanded
312312
op = Operation(

test/gemm.jl

+16-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,22 @@
4747
if LoopVectorization.REGISTER_COUNT != 8
4848
@test LoopVectorization.choose_order(lsAmulB3) == (Symbol[:n,:m,:k], :m, :n, :m, Unum, Tnum)
4949
end
50-
50+
if LoopVectorization.REGISTER_COUNT != 8
51+
for (fA,fB,v,Un,Tn) ∈ [(identity,identity,:m,Unum,Tnum),(adjoint,identity,:k,Unumt,Tnumt),(identity,adjoint,:m,Unum,Tnum),(adjoint,adjoint,:n,Unum,Tnum)]
52+
A = fA(rand(2,2))
53+
B = fB(rand(2,2))
54+
C = similar(A)
55+
ls = LoopVectorization.@avx_debug for m ∈ axes(A,1), n ∈ axes(B,2)
56+
ΔCₘₙ = zero(eltype(C))
57+
for k ∈ axes(A,2)
58+
ΔCₘₙ += A[m,k] * B[k,n]
59+
end
60+
C[m,n] += ΔCₘₙ
61+
end
62+
(m, n) = v === :n ? (:n, :m) : (:m, :n)
63+
@test LoopVectorization.choose_order(ls) == (Symbol[:n,:m,:k], m, n, v, Un, Tn)
64+
end
65+
end
5166
function AmulB!(C, A, B)
5267
C .= 0
5368
for k ∈ axes(A,2), j ∈ axes(B,2)

test/offsetarrays.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,19 @@ using LoopVectorization.VectorizationBase: StaticUnitRange
110110
# Manually unpack the OffsetArray
111111
@avx for j in rng2, i in rng1
112112
tmp_0 = zero(eltype(out))
113-
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[(ik-2)+i,(jk-2)+j] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
113+
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[(ik-2)+i,(jk-2) + j*1] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
114114
out[i,j] = tmp_9
115115
end
116116
out
117117
end
118118
function avx2dunrolled2x2!(out::AbstractMatrix, A::AbstractMatrix, kern::SizedOffsetMatrix{T,-1,1,-1,1}) where {T}
119119
rng1, rng2 = axes(out)
120-
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> kern_ik_jk = kern[ik-2,jk-2]
121120
# Manually unpack the OffsetArray
122121
@avx unroll=(2,2) for j in rng2, i in rng1
122+
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> kern_ik_jk = kern[ik - 2, jk + (-2)]
123123
tmp_0 = zero(eltype(out))
124-
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[i+(ik-2),j+(jk-2)] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
124+
j1 = j * 1
125+
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[i + (ik-2), (jk-2) + j1] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
125126
out[i,j] = tmp_9
126127
end
127128
out
@@ -132,7 +133,7 @@ using LoopVectorization.VectorizationBase: StaticUnitRange
132133
# Manually unpack the OffsetArray
133134
@avx unroll=(3,3) for j in rng2, i in rng1
134135
tmp_0 = zero(eltype(out))
135-
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[i+(ik-2),j+(jk-2)] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
136+
Base.Cartesian.@nexprs 3 jk -> Base.Cartesian.@nexprs 3 ik -> tmp_{ik+(jk-1)*3} = A[(ik-2) + i, j*1 + (jk-2)] * kern_ik_jk + tmp_{ik+(jk-1)*3-1}
136137
out[i,j] = tmp_9
137138
end
138139
out

0 commit comments

Comments
 (0)