Skip to content

Commit 5e2574d

Browse files
authored
Add correct storage_type for lazy wrapped GPU arrays (adjoint, transpose, Diagonal) (#365)
* Add storage_type for adjoint, tranpose and Diagonal s.t. they work for GPU arrays * Make lazy wrapper tests for GPUs more consistent
1 parent 7013aa6 commit 5e2574d

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

src/abstract.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ storage_type(op::AbstractLinearOperator) = error("please implement storage_type
174174
storage_type(op::LinearOperator) = typeof(op.Mv5)
175175
storage_type(M::AbstractMatrix{T}) where {T} = Vector{T}
176176

177+
# Lazy wrappers
178+
storage_type(op::Adjoint) = storage_type(parent(op))
179+
storage_type(op::Transpose) = storage_type(parent(op))
180+
storage_type(op::Diagonal) = typeof(parent(op))
181+
177182
"""
178183
reset!(op)
179184

test/gpu/amdgpu.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,10 @@ using LinearOperators, AMDGPU
1111
y = M * v
1212
@test y isa ROCArray{Float32}
1313

14+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
15+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
16+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
17+
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)
18+
1419
@testset "AMDGPU S kwarg" test_S_kwarg(arrayType = ROCArray)
1520
end

test/gpu/nvidia.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,11 @@ using LinearOperators, CUDA, CUDA.CUSPARSE, CUDA.CUSOLVER
1313
v = CUDA.rand(35)
1414
y = M * v
1515
@test y isa CuVector{Float32}
16+
17+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
18+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
19+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
20+
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)
21+
1622
@testset "Nvidia S kwarg" test_S_kwarg(arrayType = CuArray)
1723
end

0 commit comments

Comments
 (0)