Skip to content

Commit 8f2e313

Browse files
committed
fix getindex bug
1 parent 5d9754a commit 8f2e313

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/linalg/mul.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,12 @@ getindex(M::MatMulMat, kj::CartesianIndex{2}) = M[kj[1], kj[2]]
166166
# MulArray
167167
#####
168168

169+
_mul(A) = A
170+
_mul(A,B,C...) = Mul(A,B,C...)
171+
169172
function getindex(M::Mul, k)
170173
A,Bs = first(M.factors), tail(M.factors)
171-
B = Mul(Bs)
174+
B = _mul(Bs...)
172175
ret = zero(eltype(M))
173176
for j = rowsupport(A, k)
174177
ret += A[k,j] * B[j]
@@ -178,7 +181,7 @@ end
178181

179182
function getindex(M::Mul, k, j)
180183
A,Bs = first(M.factors), tail(M.factors)
181-
B = Mul(Bs)
184+
B = _mul(Bs...)
182185
ret = zero(eltype(M))
183186
@inbounds forin (rowsupport(A,k) colsupport(B,j))
184187
ret += A[k,ℓ] * B[ℓ,j]

test/multests.jl

+9
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,18 @@ import Base.Broadcast: materialize, materialize!
672672
M = MulArray(A,A)
673673
@test Matrix(M) A^2
674674
end
675+
676+
@testset "Bug in getindex" begin
677+
M = MulArray([1,2,3],Ones(1,20))
678+
@test M[1,1] == 1
679+
@test M[2,1] == 2
680+
end
675681
end
676682

677683

684+
685+
686+
678687
@testset "Add" begin
679688
@testset "gemv Float64" begin
680689
for A in (Add(randn(5,5), randn(5,5)),

0 commit comments

Comments
 (0)