Skip to content

Commit a320e07

Browse files
committed
Fix 106, bump SIMDPirates to resolve 105.
1 parent 652f368 commit a320e07

13 files changed

+96
-29
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2019 Chris Elrod
1+
Copyright (c) 2019 Chris Elrod, Eli Lilly & Co
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ OffsetArrays = "1"
1818
SIMDPirates = "0.7.16"
1919
SLEEFPirates = "0.4.4"
2020
UnPack = "0"
21-
VectorizationBase = "0.10.5"
21+
VectorizationBase = "0.11"
2222
julia = "1.1"
2323

2424
[extras]

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ include("split_loops.jl")
5656
include("condense_loopset.jl")
5757
include("reconstruct_loopset.jl")
5858
include("constructors.jl")
59+
include("user_api_conveniences.jl")
5960

6061
"""
6162
LoopVectorization provides macros and functions that combine SIMD vectorization and

src/add_ifelse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function add_andblock!(ls::LoopSet, condexpr::Expr, condeval::Expr, elementbytes
5858
condop = add_operation!(ls, gensym(:mask), condexpr, elementbytes, position)
5959
if condeval.head === :call
6060
@assert first(condeval.args) === :setindex!
61-
array, raw_indices = ref_from_setindex(condeval)
61+
array, raw_indices = ref_from_setindex!(ls, condeval)
6262
ref = Expr(:ref, array, raw_indices...)
6363
return add_andblock!(ls, condop, ref, condeval.args[3], elementbytes, position)
6464
end
@@ -99,7 +99,7 @@ function add_orblock!(ls::LoopSet, condexpr::Expr, condeval::Expr, elementbytes:
9999
condop = add_operation!(ls, gensym(:mask), condexpr, elementbytes, position)
100100
if condeval.head === :call
101101
@assert first(condeval.args) === :setindex!
102-
array, raw_indices = ref_from_setindex(condeval)
102+
array, raw_indices = ref_from_setindex!(ls, condeval)
103103
return add_orblock!(ls, condop, Expr(:ref, array, raw_indices...), condeval.args[3], elementbytes, position)
104104
end
105105
@assert condeval.head === :(=)

src/add_loads.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ function add_simple_load!(
5252
pushop!(ls, op, var)
5353
end
5454
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)
55-
array, rawindices = ref_from_ref(ex)
55+
array, rawindices = ref_from_ref!(ls, ex)
5656
add_load!(ls, var, array, rawindices, elementbytes)
5757
end
5858
function add_load_getindex!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)
59-
array, rawindices = ref_from_getindex(ex)
59+
array, rawindices = ref_from_getindex!(ls, ex)
6060
add_load!(ls, var, array, rawindices, elementbytes)
6161
end
6262

src/add_stores.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function add_simple_store!(ls::LoopSet, var::Symbol, ref::ArrayReference, elemen
6868
add_simple_store!(ls, getop(ls, var, elementbytes), ref, elementbytes)
6969
end
7070
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)
71-
array, raw_indices = ref_from_ref(ex)
71+
array, raw_indices = ref_from_ref!(ls, ex)
7272
add_store!(ls, var, array, raw_indices, elementbytes)
7373
end
7474
function add_store_ref!(ls::LoopSet, var, ex::Expr, elementbytes::Int)
@@ -80,14 +80,14 @@ function add_store_ref!(ls::LoopSet, var, ex::Expr, elementbytes::Int)
8080
add_store_ref!(ls, name(c), ex, elementbytes)
8181
end
8282
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int)
83-
array, raw_indices = ref_from_setindex(ex)
83+
array, raw_indices = ref_from_setindex!(ls, ex)
8484
add_store!(ls, (ex.args[3])::Symbol, array, raw_indices, elementbytes)
8585
end
8686

8787
# For now, it is illegal to load from a conditional store.
8888
# if you want that sort of behavior, do a conditional reassignment, and store that result unconditionally.
8989
function add_conditional_store!(ls::LoopSet, LHS, condop::Operation, storeop::Operation, elementbytes::Int)
90-
array, rawindices = ref_from_ref(LHS)
90+
array, rawindices = ref_from_ref!(ls, LHS)
9191
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
9292
mref = mpref.mref
9393
ldref = mpref.loopdependencies

src/condense_loopset.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,6 @@ function setup_call(ls::LoopSet, q = nothing, inline::Bool = true, u₁ = zero(I
450450
else
451451
setup_call_noinline(ls, u₁, u₂)
452452
end
453-
isnothing(q) && return call
454-
Expr(:if, check_args_call(ls), call, q)
453+
isnothing(q) && return Expr(:block, ls.prepreamble, call)
454+
Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), call, q))
455455
end

src/graphs.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ struct LoopSet
167167
outer_reductions::Vector{Int} # IDs of reduction operations that need to be reduced at end.
168168
loop_order::LoopOrder
169169
preamble::Expr
170+
prepreamble::Expr # performs extractions that must be performed first, and don't need further registering
170171
preamble_symsym::Vector{Tuple{Int,Symbol}}
171172
preamble_symint::Vector{Tuple{Int,Int}}
172173
preamble_symfloat::Vector{Tuple{Int,Float64}}
@@ -206,7 +207,7 @@ function save_tilecost!(ls::LoopSet)
206207
end
207208
end
208209

209-
210+
pushprepreamble!(ls::LoopSet, ex) = push!(ls.prepreamble.args, ex)
210211
function pushpreamble!(ls::LoopSet, op::Operation, v::Symbol)
211212
if v !== mangledvar(op)
212213
push!(ls.preamble_symsym, (identifier(op),v))
@@ -270,7 +271,7 @@ function LoopSet(mod::Symbol)
270271
Operation[], [0],
271272
Int[],
272273
LoopOrder(),
273-
Expr(:block),
274+
Expr(:block),Expr(:block),
274275
Tuple{Int,Symbol}[],
275276
Tuple{Int,Int}[],
276277
Tuple{Int,Float64}[],
@@ -511,7 +512,7 @@ function add_operation!(
511512
ls::LoopSet, LHS_sym::Symbol, RHS::Expr, LHS_ref::ArrayReferenceMetaPosition, elementbytes::Int, position::Int
512513
)
513514
if RHS.head === :ref# || (RHS.head === :call && first(RHS.args) === :getindex)
514-
array, rawindices = ref_from_expr(RHS)
515+
array, rawindices = ref_from_expr!(ls, RHS)
515516
RHS_ref = array_reference_meta!(ls, array, rawindices, elementbytes, gensym(LHS_sym))
516517
op = add_load!(ls, RHS_ref, elementbytes)
517518
iop = add_compute!(ls, LHS_sym, :identity, [op], elementbytes)
@@ -559,7 +560,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
559560
elseif RHS isa Expr
560561
# need to check if LHS appears in RHS
561562
# assign RHS to lrhs
562-
array, rawindices = ref_from_expr(LHS)
563+
array, rawindices = ref_from_expr!(ls, LHS)
563564
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
564565
cachedparents = copy(mpref.parents)
565566
ref = mpref.mref.ref

src/lower_compute.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ function lower_compute!(
163163
# want to instcombine when parent load's deps are superset
164164
# also make sure opp is unrolled
165165
if instrfid !== nothing && (opunrolled && u₁ > 1) && !load_constrained(op, u₁loopsym, u₂loopsym)
166-
specific_fmas = Base.libllvm_version > v"9.0.0" ? (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub) : (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
167-
# specific_fmas = (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
166+
# specific_fmas = Base.libllvm_version > v"9.0.0" ? (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub) : (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
167+
specific_fmas = (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
168168
instr = Instruction(specific_fmas[instrfid])
169169
end
170170
end

src/memory_ops_common.jl

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
1-
function ref_from_expr(ex, offset1::Int, offset2::Int)
2-
(ex.args[1 + offset1])::Symbol, @view(ex.args[2 + offset2:end])
1+
function extract_array_symbol_from_ref!(ls::LoopSet, ex::Expr, offset1::Int)::Symbol
2+
ar = ex.args[1 + offset1]
3+
if isa(ar, Symbol)
4+
return ar
5+
elseif isa(ar, Expr) && ar.head === :(.)
6+
s = gensym(:extractedarray)
7+
pushprepreamble!(ls, Expr(:(=), s, ar))
8+
return s
9+
else
10+
throw("Indexing into the following expression was not recognized: $ar")
11+
end
12+
end
13+
14+
15+
function ref_from_expr!(ls, ex, offset1::Int, offset2::Int)
16+
ar = extract_array_symbol_from_ref!(ls, ex, offset1)
17+
ar, @view(ex.args[2 + offset2:end])
318
end
4-
ref_from_ref(ex::Expr) = ref_from_expr(ex, 0, 0)
5-
ref_from_getindex(ex::Expr) = ref_from_expr(ex, 1, 1)
6-
ref_from_setindex(ex::Expr) = ref_from_expr(ex, 1, 2)
7-
function ref_from_expr(ex::Expr)
19+
ref_from_ref!(ls::LoopSet, ex::Expr) = ref_from_expr!(ls, ex, 0, 0)
20+
ref_from_getindex!(ls::LoopSet, ex::Expr) = ref_from_expr!(ls, ex, 1, 1)
21+
ref_from_setindex!(ls::LoopSet, ex::Expr) = ref_from_expr!(ls, ex, 1, 2)
22+
function ref_from_expr!(ls::LoopSet, ex::Expr)
823
if ex.head === :ref
9-
ref_from_ref(ex)
24+
ref_from_ref!(ls, ex)
1025
else#if ex.head === :call
1126
f = first(ex.args)::Symbol
12-
f === :getindex ? ref_from_getindex(ex) : ref_from_setindex(ex)
27+
f === :getindex ? ref_from_getindex!(ls, ex) : ref_from_setindex!(ls, ex)
1328
end
1429
end
1530

@@ -159,13 +174,13 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
159174
end
160175
function tryrefconvert(ls::LoopSet, ex::Expr, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)::Tuple{Bool,ArrayReferenceMetaPosition}
161176
ya, yinds = if ex.head === :ref
162-
ref_from_ref(ex)
177+
ref_from_ref!(ls, ex)
163178
elseif ex.head === :call
164179
f = first(ex.args)
165180
if f === :getindex
166-
ref_from_getindex(ex)
181+
ref_from_getindex!(ls, ex)
167182
elseif f === :setindex!
168-
ref_from_setindex(ex)
183+
ref_from_setindex!(ls, ex)
169184
else
170185
return false, NOTAREFERENCEMP
171186
end

src/user_api_conveniences.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
# This is a convenience function for libraries like MaBLAS and PaddedMatrices to use explicitly, and for others to force more type inference in precompilation.
3+
let GEMMLOOPSET = LoopVectorization.LoopSet(
4+
:(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
5+
Cₘₙ = zero(eltype(C))
6+
for k ∈ 1:size(A,2)
7+
Cₘₙ += A[m,k] * B[k,n]
8+
end
9+
C[m,n] += Cₘₙ
10+
end)
11+
);
12+
order = LoopVectorization.choose_order(GEMMLOOPSET)
13+
mr = order[5]
14+
nr = last(order)
15+
@eval const mᵣ = $mr
16+
@eval const nᵣ = $nr
17+
end
18+

test/gemm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@testset "GEMM" begin
22
# using LoopVectorization, LinearAlgebra, Test; T = Float64
33
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (5, 5)
4+
@test LoopVectorization.mᵣ == Unum
5+
@test LoopVectorization.nᵣ == Tnum
46
AmulBtq1 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
57
C[m,n] = zeroB
68
for k ∈ 1:size(A,2)

test/miscellaneous.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,31 @@ using Test
619619
c_re
620620
end
621621

622-
622+
struct MatHolder{T}
623+
d :: Vector{T}
624+
Wt :: Matrix{T}
625+
WtD :: Matrix{T}
626+
end
627+
628+
function MatCalcWtD(m::MatHolder)
629+
l, n = size(m.Wt)
630+
@avx for j in 1:n
631+
for i in 1:l
632+
m.WtD[i, j] = m.Wt[i, j] * m.d[j]
633+
end
634+
end
635+
end
636+
function MatHolder(
637+
d :: Vector{T},
638+
Wt :: Matrix{T}
639+
) where {T}
640+
l, n = size(Wt)
641+
@assert length(d) == n
642+
WtD = Matrix{T}(undef, l, n)
643+
MatHolder(d, Wt, WtD)
644+
end
645+
646+
623647
for T ∈ (Float32, Float64)
624648
@show T, @__LINE__
625649
A = randn(T, 199, 498);
@@ -804,6 +828,12 @@ using Test
804828
multiple_unrolls_split_depchains!(c_re_1, a_re, b_re, a_im, b_im) # [1 1; 1 1]
805829
multiple_unrolls_split_depchains_avx!(c_re_2, a_re, b_re, a_im, b_im) # [1 1; 1 1]
806830
@test c_re_1 ≈ c_re_2
831+
832+
mh = MatHolder(rand(T, 23), rand(T, 15,23));
833+
MatCalcWtD(mh)
834+
@test mh.WtD ≈ mh.Wt .* mh.d'
835+
836+
807837
end
808838
for T ∈ [Int16, Int32, Int64]
809839
n = 8sizeof(T) - 1

0 commit comments

Comments
 (0)