Skip to content

Commit c21c174

Browse files
committed
A few bug fixes and added tests.
1 parent 9462c41 commit c21c174

File tree

5 files changed

+105
-6
lines changed

5 files changed

+105
-6
lines changed

src/add_ifelse.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
1717
end
1818
iftrue = RHS.args[2]
1919
if iftrue isa Expr
20-
trueop = add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
20+
trueop = add_operation!(ls, gensym(:iftrue), iftrue, elementbytes, position)
2121
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop)) && !search_tree(parents(condop), trueop)
2222
trueop.instruction = Instruction(:conditionalload)
2323
push!(parents(trueop), condop)
@@ -27,7 +27,7 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
2727
end
2828
iffalse = RHS.args[3]
2929
if iffalse isa Expr
30-
falseop = add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
30+
falseop = add_operation!(ls, gensym(:iffalse), iffalse, elementbytes, position)
3131
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop)) && !search_tree(parents(condop), falseop)
3232
falseop.instruction = Instruction(:conditionalload)
3333
push!(parents(falseop), negateop!(ls, condop, elementbytes))

src/constructors.jl

+10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11

22
### This file contains convenience functions for constructing LoopSets.
33

4+
# function strip_unneeded_const_deps!(ls::LoopSet)
5+
# for op ∈ operations(ls)
6+
# if isconstant(op) && iszero(length(reducedchildren(op)))
7+
# op.dependencies = NODEPENDENCY
8+
# end
9+
# end
10+
# ls
11+
# end
12+
413
# Fill the LoopSet `ls` from the loop expression `q`.
#
# `q` must be a `for` expression; anything else is rejected by throwing the
# message string below. The work itself is delegated to `add_loop!`, whose
# result is returned.
function Base.copyto!(ls::LoopSet, q::Expr)
    if q.head !== :for
        throw("Expression must be a for loop.")
    end
    # NOTE(review): the literal 8 is presumably the default element size in bytes;
    # confirm against `add_loop!`'s signature.
    add_loop!(ls, q, 8)
    # strip_unneeded_const_deps!(ls)
end
818

919
function add_ci_call!(q::Expr, @nospecialize(f), args, syms, i, mod = nothing)

src/graphs.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,10 @@ end
555555
instruction!(ls::LoopSet, x::Symbol) = instruction(x)
556556

557557

558-
function maybe_const_compute!(ls::LoopSet, op::Operation, elementbytes::Int, position::Int)
558+
function maybe_const_compute!(ls::LoopSet, LHS::Symbol, op::Operation, elementbytes::Int, position::Int)
559+
# return op
559560
if iscompute(op) && iszero(length(loopdependencies(op)))
560-
add_constant!(ls, mangledvar(op), ls.loopsymbols[1:position], gensym(instruction(op).instr), elementbytes, :numericconstant)
561+
ls.opdict[LHS] = add_constant!(ls, LHS, ls.loopsymbols[1:position], gensym(instruction(op).instr), elementbytes, :numericconstant)
561562
else
562563
op
563564
end
@@ -572,6 +573,9 @@ function strip_op_linenumber_nodes(q::Expr)
572573
end
573574
end
574575

576+
# `add_operation!` method for a right-hand side that is a bare `Symbol` rather than
# an `Expr`: forwards to `add_constant!`, passing `RHS`, the first `position`
# entries of `ls.loopsymbols`, `LHS`, and `elementbytes`, and returns its result.
function add_operation!(ls::LoopSet, LHS::Symbol, RHS::Symbol, elementbytes::Int, position::Int)
    enclosing_loops = ls.loopsymbols[1:position]
    return add_constant!(ls, RHS, enclosing_loops, LHS, elementbytes)
end
575579
function add_operation!(
576580
ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int
577581
)
@@ -670,7 +674,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
670674
RHS = ex.args[2]
671675
if LHS isa Symbol
672676
if RHS isa Expr
673-
maybe_const_compute!(ls, add_operation!(ls, LHS, RHS, elementbytes, position), elementbytes, position)
677+
maybe_const_compute!(ls, LHS, add_operation!(ls, LHS, RHS, elementbytes, position), elementbytes, position)
674678
else
675679
add_constant!(ls, RHS, ls.loopsymbols[1:position], LHS, elementbytes)
676680
end
@@ -689,7 +693,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
689693
elseif LHS.head === :tuple
690694
@assert length(LHS.args) ≀ 9 "Functions returning more than 9 values aren't currently supported."
691695
lhstemp = gensym(:lhstuple)
692-
vparents = Operation[maybe_const_compute!(ls, add_operation!(ls, lhstemp, RHS, elementbytes, position), elementbytes, position)]
696+
vparents = Operation[maybe_const_compute!(ls, lhstemp, add_operation!(ls, lhstemp, RHS, elementbytes, position), elementbytes, position)]
693697
for i ∈ eachindex(LHS.args)
694698
f = (:first,:second,:third,:fourth,:fifth,:sixth,:seventh,:eighth,:ninth)[i]
695699
lhsi = LHS.args[i]

test/ifelsemasks.jl

+49
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,36 @@ T = Float32
317317
x[i] = yα΅’ * zα΅’
318318
end
319319
end
320+
321+
# Plain-Julia reference kernel with two consecutive if/else selections.
#
# For every index `j` along the first axis of `half`:
#   * `𝓇𝒽𝓈` is `log(half[j]) + m[j]` when `keep` is `nothing`, otherwise that
#     value added to the current `res[j]` (so `res` is read in that case);
#   * `res[j]` is then set to `𝓇𝒽𝓈` when `final` is `nothing`, otherwise to
#     `exp(𝓇𝒽𝓈)`.
# Returns `res`.
function twoifelses!(res, half, m, keep=nothing, final=true)
    𝒶𝓍j = axes(half, 1)
    for j in 𝒶𝓍j
        𝓇𝒽𝓈 = if isnothing(keep)
            log(half[j]) + m[j]
        else
            res[j] + (log(half[j]) + m[j])
        end
        res[j] = isnothing(final) ? 𝓇𝒽𝓈 : exp(𝓇𝒽𝓈)
    end
    return res
end
333+
# `@avx` counterpart of `twoifelses!`: the same two if/else selections, with the
# loop handed to the `@avx` macro. Both selections are written as `if ... else`
# blocks inside the macro'd loop.
#
# For every `j` along the first axis of `half`, `𝓇𝒽𝓈` is `log(half[j]) + m[j]`
# (plus `res[j]` when `keep` is not `nothing`), and `res[j]` receives `𝓇𝒽𝓈` when
# `final` is `nothing`, otherwise `exp(𝓇𝒽𝓈)`. Returns `res`.
function twoifelses_avx!(res, half, m, keep=nothing, final=true)
    𝒶𝓍j = axes(half, 1)
    @avx for j in 𝒶𝓍j
        𝓇𝒽𝓈 = if isnothing(keep)
            log(half[j]) + m[j]
        else
            res[j] + (log(half[j]) + m[j])
        end
        res[j] = if isnothing(final)
            𝓇𝒽𝓈
        else
            exp(𝓇𝒽𝓈)
        end
    end
    res
end
349+
320350
N = 117
321351
for T ∈ (Float32, Float64, Int32, Int64)
322352
@show T, @__LINE__
@@ -425,6 +455,25 @@ T = Float32
425455
condloadscalar!(C1, B, b, d)
426456
condloadscalaravx!(C2, B, b, d)
427457
@test C1 β‰ˆ C2
458+
459+
# Inputs for the twoifelses comparison. Integer element types draw `half` from
# 1:100 so that `log(half[j])` is finite.
if T <: Integer
    half = rand(T(1):T(100), 7);
    m = rand(T(-10):T(10), 7);
else
    half = rand(T, 7); m = rand(T, 7);
end;
# Element types of size 4 write into Float32 results, all others into Float64.
Tres = sizeof(T) == 4 ? Float32 : Float64
res0 = rand(Tres, 7)
res1 = similar(res0)
res2 = similar(res0)

for keep ∈ (nothing,true), final ∈ (nothing,true)
    # When `keep` is not `nothing` both kernels read `res[j]` before writing it,
    # so the two buffers must start from identical, defined contents.
    copyto!(res1, res0); copyto!(res2, res0)
    # Forward the loop variables: without them every iteration re-tests the defaults.
    @test twoifelses!(res1, half, m, keep, final) ≈ twoifelses_avx!(res2, half, m, keep, final)
end
476+
428477
end
429478

430479

test/miscellaneous.jl

+36
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,38 @@ function manyreturntestavx(x)
752752
s
753753
end
754754

755+
# Plain-Julia reference kernel for the issue-144 regression test.
#
# For every `(i, j)` of `mat`, adds
# `conj(conj(𝛥ℛ.value) * (exp(-ℛ[j]) * exp(mat[i, j])))` to `𝛥mat[i, j]`.
# Several intermediates do not depend on the inner loop index (or on either
# index) yet are deliberately computed inside both loops, as marked below.
# Returns `𝛥mat`.
function maybe_const_issue144!(𝛥mat, 𝛥ℛ, mat, ℛ)
    𝛥ℛ_value = 𝛥ℛ.value
    for j in axes(mat,2)
        for i in axes(mat,1)
            ℰ𝓍1 = conj(𝛥ℛ_value) # could be outside both loops
            ℰ𝓍2 = -(ℛ[j]) # could be outside i loop
            ℰ𝓍3 = exp(ℰ𝓍2) # could be outside i loop
            ℰ𝓍4 = exp(mat[i, j])
            ℰ𝓍5 = ℰ𝓍3 * ℰ𝓍4
            ℰ𝓍6 = ℰ𝓍1 * ℰ𝓍5
            ℰ𝓍7 = conj(ℰ𝓍6)
            𝛥mat[i, j] = 𝛥mat[i, j] + ℰ𝓍7
        end
    end
    return 𝛥mat
end
771+
# `@avx` counterpart of `maybe_const_issue144!`: the same loop nest and the same
# chain of intermediates, with the outer loop handed to the `@avx` macro.
#
# For every `(i, j)` of `mat`, adds
# `conj(conj(𝛥ℛ.value) * (exp(-ℛ[j]) * exp(mat[i, j])))` to `𝛥mat[i, j]`.
# `ℰ𝓍1` depends on neither loop index, and `ℰ𝓍2`/`ℰ𝓍3` do not depend on `i`.
# Returns `𝛥mat`.
function maybe_const_issue144_avx!(𝛥mat, 𝛥ℛ, mat, ℛ)
    𝛥ℛ_value = 𝛥ℛ.value
    @avx for j in axes(mat,2)
        for i in axes(mat,1)
            ℰ𝓍1 = conj(𝛥ℛ_value)
            ℰ𝓍2 = -(ℛ[j])
            ℰ𝓍3 = exp(ℰ𝓍2)
            ℰ𝓍4 = exp(mat[i, j])
            ℰ𝓍5 = ℰ𝓍3 * ℰ𝓍4
            ℰ𝓍6 = ℰ𝓍1 * ℰ𝓍5
            ℰ𝓍7 = conj(ℰ𝓍6)
            𝛥mat[i, j] = 𝛥mat[i, j] + ℰ𝓍7
        end
    end
    𝛥mat
end
755787

756788
for T ∈ (Float32, Float64)
757789
@show T, @__LINE__
@@ -970,6 +1002,10 @@ end
9701002

9711003
@test all(isequal(81), powcseliteral!(E0))
9721004
@test all(isequal(81), powcsesymbol!(E3))
1005+
1006+
# The @avx kernel must agree with the plain-Julia reference on a 3×4 problem.
@test maybe_const_issue144!(zeros(T, 3,4), (value=one(T),), collect(reshape(1:12, 3,4)), ones(T, 4)) ≈ maybe_const_issue144_avx!(zeros(T,3,4), (value=one(T),), collect(reshape(1:12, 3,4)), ones(T,4))
1007+
1008+
9731009
end
9741010
for T ∈ [Int16, Int32, Int64]
9751011
n = 8sizeof(T) - 1

0 commit comments

Comments
(0)