Skip to content

Commit 4b481a3

Browse files
committed
Fix broadcasting products for general access patterns.
1 parent 75daf98 commit 4b481a3

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/broadcast.jl

+13-2
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ function add_broadcast!(
5757
bloopsyms = Symbol[k]
5858
cloopsyms = Symbol[m]
5959
reductdeps = Symbol[m, k]
60+
kvec = bloopsyms
6061
elseif ndims(B) == 2
6162
n = loopsyms[2];
6263
bloopsyms = Symbol[k,n]
6364
cloopsyms = Symbol[m,n]
6465
reductdeps = Symbol[m, k, n]
66+
kvec = Symbol[k]
6567
else
6668
throw("B must be a vector or matrix.")
6769
end
@@ -72,13 +74,22 @@ function add_broadcast!(
7274
loadB = add_broadcast!(ls, gensym(:B), mB, bloopsyms, B, elementbytes)
7375
# set Cₘₙ = 0
7476
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
77+
# targetC will be used for reduce_to_add
78+
mCt = gensym(mC)
79+
targetC = add_constant!(ls, gensym(:zero), cloopsyms, mCt, elementbytes, :numericconstant)
80+
push!(ls.preamble_zeros, (identifier(targetC), IntOrFloat))
7581
setC = add_constant!(ls, gensym(:zero), cloopsyms, mC, elementbytes, :numericconstant)
7682
push!(ls.preamble_zeros, (identifier(setC), IntOrFloat))
83+
setC.reduced_children = kvec
7784
# compute Cₘₙ += Aₘₖ * Bₖₙ
7885
reductop = Operation(
79-
ls, mC, elementbytes, :vmuladd, compute, reductdeps, Symbol[k], Operation[loadA, loadB, setC]
86+
ls, mC, elementbytes, :vmuladd, compute, reductdeps, kvec, Operation[loadA, loadB, setC]
8087
)
81-
pushop!(ls, reductop, mC)
88+
reductop = pushop!(ls, reductop, mC)
89+
reductfinal = Operation(
90+
ls, mCt, elementbytes, :reduce_to_add, compute, cloopsyms, kvec, Operation[reductop, targetC]
91+
)
92+
pushop!(ls, reductfinal, mCt)
8293
end
8394

8495
struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}

test/runtests.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1255,10 +1255,12 @@ end
12551255

12561256
M, K, N = 77, 83, 57;
12571257
A = rand(R,M,K); B = rand(R,K,N); C = rand(R,M,N);
1258-
1258+
At = copy(A')
12591259
D1 = C .+ A * B;
12601260
D2 = @avx C .+ A *ˡ B;
12611261
@test D1 ≈ D2
1262+
fill!(D2, -999999); D2 = @avx C .+ At' *ˡ B;
1263+
@test D1 ≈ D2
12621264
if T <: Union{Float32,Float64}
12631265
D3 = cos.(B');
12641266
D4 = @avx cos.(B');

0 commit comments

Comments
 (0)