@@ -57,11 +57,13 @@ function add_broadcast!(
57
57
bloopsyms = Symbol[k]
58
58
cloopsyms = Symbol[m]
59
59
reductdeps = Symbol[m, k]
60
+ kvec = bloopsyms
60
61
elseif ndims (B) == 2
61
62
n = loopsyms[2 ];
62
63
bloopsyms = Symbol[k,n]
63
64
cloopsyms = Symbol[m,n]
64
65
reductdeps = Symbol[m, k, n]
66
+ kvec = Symbol[k]
65
67
else
66
68
throw (" B must be a vector or matrix." )
67
69
end
@@ -72,13 +74,22 @@ function add_broadcast!(
72
74
loadB = add_broadcast! (ls, gensym (:B ), mB, bloopsyms, B, elementbytes)
73
75
# set Cₘₙ = 0
74
76
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
77
+ # targetC will be used for reduce_to_add
78
+ mCt = gensym (mC)
79
+ targetC = add_constant! (ls, gensym (:zero ), cloopsyms, mCt, elementbytes, :numericconstant )
80
+ push! (ls. preamble_zeros, (identifier (targetC), IntOrFloat))
75
81
setC = add_constant! (ls, gensym (:zero ), cloopsyms, mC, elementbytes, :numericconstant )
76
82
push! (ls. preamble_zeros, (identifier (setC), IntOrFloat))
83
+ setC. reduced_children = kvec
77
84
# compute Cₘₙ += Aₘₖ * Bₖₙ
78
85
reductop = Operation (
79
- ls, mC, elementbytes, :vmuladd , compute, reductdeps, Symbol[k] , Operation[loadA, loadB, setC]
86
+ ls, mC, elementbytes, :vmuladd , compute, reductdeps, kvec , Operation[loadA, loadB, setC]
80
87
)
81
- pushop! (ls, reductop, mC)
88
+ reductop = pushop! (ls, reductop, mC)
89
+ reductfinal = Operation (
90
+ ls, mCt, elementbytes, :reduce_to_add , compute, cloopsyms, kvec, Operation[reductop, targetC]
91
+ )
92
+ pushop! (ls, reductfinal, mCt)
82
93
end
83
94
84
95
struct LowDimArray{D,T,N,A<: DenseArray{T,N} } <: DenseArray{T,N}
0 commit comments