Skip to content

Commit 2f90dde

Browse files
tkf and mbauman
authored
Faster mapreduce for Broadcasted (#31020)
* Better mapreduce for Broadcasted
* Use Axes in IndexStyle for Broadcasted
* Apply suggestions from code review
  Co-Authored-By: tkf <29282+tkf@users.noreply.github.com>
* Update base/broadcast.jl
  Co-Authored-By: tkf <29282+tkf@users.noreply.github.com>
* Fix IndexStyle for IndexLinear case
* Fix LinearIndices for Broadcasted
* Test that pairwise mapreduce is used
* Test count(::Broadcasted)
* Support Broadcasted in mapreducedim!

Co-authored-by: Matt Bauman <mbauman@gmail.com>
1 parent ae2063f commit 2f90dde

File tree

4 files changed

+116
-33
lines changed

4 files changed

+116
-33
lines changed

base/broadcast.jl

+26 -10
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
166166
# methods that instead specialize on `BroadcastStyle`,
167167
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})
168168

169-
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple}
169+
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
170170
f::F
171171
args::Args
172172
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
@@ -193,21 +193,25 @@ function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
193193
end
194194

195195
## Allocating the output container
196-
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}) where {N,ElType} =
197-
similar(Array{ElType}, axes(bc))
198-
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}) where N =
199-
similar(BitArray, axes(bc))
196+
Base.similar(bc::Broadcasted, ::Type{T}) where {T} = similar(bc, T, axes(bc))
197+
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}, dims) where {N,ElType} =
198+
similar(Array{ElType}, dims)
199+
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}, dims) where N =
200+
similar(BitArray, dims)
200201
# In cases of conflict we fall back on Array
201-
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{ElType}) where ElType =
202-
similar(Array{ElType}, axes(bc))
203-
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) =
204-
similar(BitArray, axes(bc))
202+
Base.similar(::Broadcasted{ArrayConflict}, ::Type{ElType}, dims) where ElType =
203+
similar(Array{ElType}, dims)
204+
Base.similar(::Broadcasted{ArrayConflict}, ::Type{Bool}, dims) =
205+
similar(BitArray, dims)
205206

206207
@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
207208
_axes(::Broadcasted, axes::Tuple) = axes
208209
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
209210
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()
210211

212+
@inline Base.axes(bc::Broadcasted{<:Any, <:NTuple{N}}, d::Integer) where N =
213+
d <= N ? axes(bc)[d] : OneTo(1)
214+
211215
BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
212216
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
213217
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))
@@ -219,6 +223,12 @@ argtype(bc::Broadcasted) = argtype(typeof(bc))
219223
_eachindex(t::Tuple{Any}) = t[1]
220224
_eachindex(t::Tuple) = CartesianIndices(t)
221225

226+
Base.IndexStyle(bc::Broadcasted) = IndexStyle(typeof(bc))
227+
Base.IndexStyle(::Type{<:Broadcasted{<:Any,<:Tuple{Any}}}) = IndexLinear()
228+
Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()
229+
230+
Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = axes(bc)[1]
231+
222232
Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
223233
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N
224234

@@ -564,7 +574,13 @@ end
564574
@boundscheck checkbounds(bc, I)
565575
@inbounds _broadcast_getindex(bc, I)
566576
end
567-
Base.@propagate_inbounds Base.getindex(bc::Broadcasted, i1::Integer, i2::Integer, I::Integer...) = bc[CartesianIndex((i1, i2, I...))]
577+
Base.@propagate_inbounds Base.getindex(
578+
bc::Broadcasted,
579+
i1::Union{Integer,CartesianIndex},
580+
i2::Union{Integer,CartesianIndex},
581+
I::Union{Integer,CartesianIndex}...,
582+
) =
583+
bc[CartesianIndex((i1, i2, I...))]
568584
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]
569585

570586
@inline Base.checkbounds(bc::Broadcasted, I::Union{Integer,CartesianIndex}) =

base/reduce.jl

+12 -8
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,9 @@ else
1212
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
1313
end
1414

15+
abstract type AbstractBroadcasted end
16+
const AbstractArrayOrBroadcasted = Union{AbstractArray, AbstractBroadcasted}
17+
1518
"""
1619
Base.add_sum(x, y)
1720
@@ -227,7 +230,8 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
227230

228231
# This is a generic implementation of `mapreduce_impl()`,
229232
# certain `op` (e.g. `min` and `max`) may have their own specialized versions.
230-
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
233+
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted,
234+
ifirst::Integer, ilast::Integer, blksize::Int)
231235
if ifirst == ilast
232236
@inbounds a1 = A[ifirst]
233237
return mapreduce_first(f, op, a1)
@@ -250,7 +254,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
250254
end
251255
end
252256

253-
mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer) =
257+
mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, ifirst::Integer, ilast::Integer) =
254258
mapreduce_impl(f, op, A, ifirst, ilast, pairwise_blocksize(f, op))
255259

256260
"""
@@ -383,13 +387,13 @@ The default is `reduce_first(op, f(x))`.
383387
"""
384388
mapreduce_first(f, op, x) = reduce_first(op, f(x))
385389

386-
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
390+
_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A)
387391

388-
function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
392+
function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted)
389393
inds = LinearIndices(A)
390394
n = length(inds)
391395
if n == 0
392-
return mapreduce_empty(f, op, T)
396+
return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
393397
elseif n == 1
394398
@inbounds a1 = A[first(inds)]
395399
return mapreduce_first(f, op, a1)
@@ -410,7 +414,7 @@ end
410414

411415
mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)
412416

413-
_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
417+
_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A)
414418

415419
"""
416420
reduce(op, itr; [init])
@@ -560,7 +564,7 @@ isgoodzero(::typeof(max), x) = isbadzero(min, x)
560564
isgoodzero(::typeof(min), x) = isbadzero(max, x)
561565

562566
function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
563-
A::AbstractArray, first::Int, last::Int)
567+
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
564568
a1 = @inbounds A[first]
565569
v1 = mapreduce_first(f, op, a1)
566570
v2 = v3 = v4 = v1
@@ -856,7 +860,7 @@ function count(pred, itr)
856860
end
857861
return n
858862
end
859-
function count(pred, a::AbstractArray)
863+
function count(pred, a::AbstractArrayOrBroadcasted)
860864
n = 0
861865
for i in eachindex(a)
862866
@inbounds n += pred(a[i])::Bool

base/reducedim.jl

+22 -15
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ No method is implemented for reducing index range of type $(typeof(i)). Please i
1212
reduced_index for this index type or report this as an issue.
1313
"""
1414
))
15-
reduced_indices(a::AbstractArray, region) = reduced_indices(axes(a), region)
15+
reduced_indices(a::AbstractArrayOrBroadcasted, region) = reduced_indices(axes(a), region)
1616

1717
# for reductions that keep 0 dims as 0
1818
reduced_indices0(a::AbstractArray, region) = reduced_indices0(axes(a), region)
@@ -89,8 +89,8 @@ for (Op, initval) in ((:(typeof(&)), true), (:(typeof(|)), false))
8989
end
9090

9191
# reducedim_initarray is called by
92-
reducedim_initarray(A::AbstractArray, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
93-
reducedim_initarray(A::AbstractArray, region, init::T) where {T} = reducedim_initarray(A, region, init, T)
92+
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
93+
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init::T) where {T} = reducedim_initarray(A, region, init, T)
9494

9595
# TODO: better way to handle reducedim initialization
9696
#
@@ -156,8 +156,8 @@ end
156156
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} =
157157
reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T))
158158

159-
reducedim_init(f, op::typeof(&), A::AbstractArray, region) = reducedim_initarray(A, region, true)
160-
reducedim_init(f, op::typeof(|), A::AbstractArray, region) = reducedim_initarray(A, region, false)
159+
reducedim_init(f, op::typeof(&), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, true)
160+
reducedim_init(f, op::typeof(|), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, false)
161161

162162
# specialize to make initialization more efficient for common cases
163163

@@ -179,8 +179,11 @@ end
179179

180180
## generic (map)reduction
181181

182-
has_fast_linear_indexing(a::AbstractArray) = false
182+
has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = false
183183
has_fast_linear_indexing(a::Array) = true
184+
has_fast_linear_indexing(::Number) = true # for Broadcasted
185+
has_fast_linear_indexing(bc::Broadcast.Broadcasted) =
186+
all(has_fast_linear_indexing, bc.args)
184187

185188
function check_reducedims(R, A)
186189
# Check whether R has compatible dimensions w.r.t. A for reduction
@@ -233,7 +236,7 @@ _firstslice(i::OneTo) = OneTo(1)
233236
_firstslice(i::Slice) = Slice(_firstslice(i.indices))
234237
_firstslice(i) = i[firstindex(i):firstindex(i)]
235238

236-
function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray)
239+
function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArrayOrBroadcasted)
237240
lsiz = check_reducedims(R,A)
238241
isempty(A) && return R
239242

@@ -271,10 +274,10 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray)
271274
return R
272275
end
273276

274-
mapreducedim!(f, op, R::AbstractArray, A::AbstractArray) =
277+
mapreducedim!(f, op, R::AbstractArray, A::AbstractArrayOrBroadcasted) =
275278
(_mapreducedim!(f, op, R, A); R)
276279

277-
reducedim!(op, R::AbstractArray{RT}, A::AbstractArray) where {RT} =
280+
reducedim!(op, R::AbstractArray{RT}, A::AbstractArrayOrBroadcasted) where {RT} =
278281
mapreducedim!(identity, op, R, A)
279282

280283
"""
@@ -304,17 +307,21 @@ julia> mapreduce(isodd, |, a, dims=1)
304307
1 1 1 1
305308
```
306309
"""
307-
mapreduce(f, op, A::AbstractArray; dims=:, kw...) = _mapreduce_dim(f, op, kw.data, A, dims)
308-
mapreduce(f, op, A::AbstractArray...; kw...) = reduce(op, map(f, A...); kw...)
310+
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims=:, kw...) =
311+
_mapreduce_dim(f, op, kw.data, A, dims)
312+
mapreduce(f, op, A::AbstractArrayOrBroadcasted...; kw...) =
313+
reduce(op, map(f, A...); kw...)
309314

310-
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, ::Colon) = mapfoldl(f, op, A; nt...)
315+
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, ::Colon) =
316+
mapfoldl(f, op, A; nt...)
311317

312-
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, ::Colon) = _mapreduce(f, op, IndexStyle(A), A)
318+
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, ::Colon) =
319+
_mapreduce(f, op, IndexStyle(A), A)
313320

314-
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, dims) =
321+
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, dims) =
315322
mapreducedim!(f, op, reducedim_initarray(A, dims, nt.init), A)
316323

317-
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, dims) =
324+
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, dims) =
318325
mapreducedim!(f, op, reducedim_init(f, op, A, dims), A)
319326

320327
"""

test/broadcast.jl

+56
Original file line number | Diff line number | Diff line change
@@ -821,6 +821,7 @@ end
821821
# Broadcasted iterable/indexable APIs
822822
let
823823
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
824+
@test IndexStyle(bc) == IndexLinear()
824825
@test eachindex(bc) === Base.OneTo(5)
825826
@test length(bc) === 5
826827
@test ndims(bc) === 1
@@ -831,6 +832,7 @@ let
831832
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
832833

833834
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4)))
835+
@test IndexStyle(bc) == IndexCartesian()
834836
@test eachindex(bc) === CartesianIndices((Base.OneTo(5), Base.OneTo(4)))
835837
@test length(bc) === 20
836838
@test ndims(bc) === 2
@@ -851,6 +853,60 @@ let a = rand(5), b = rand(5), c = copy(a)
851853
@test x == [2]
852854
end
853855

856+
@testset "broadcasted mapreduce" begin
857+
xs = 1:10
858+
ys = 1:2:20
859+
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
860+
@test IndexStyle(bc) == IndexLinear()
861+
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
862+
863+
xs2 = reshape(xs, 1, :)
864+
ys2 = reshape(ys, 1, :)
865+
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs2, ys2))
866+
@test IndexStyle(bc) == IndexCartesian()
867+
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
868+
869+
xs = 1:5:3*5
870+
ys = 1:4:3*4
871+
bc = Broadcast.instantiate(
872+
Broadcast.broadcasted(iseven, Broadcast.broadcasted(-, xs, ys)))
873+
@test count(bc) == count(iseven, map(-, xs, ys))
874+
875+
xs = reshape(1:6, (2, 3))
876+
ys = 1:2
877+
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
878+
@test reduce(+, bc; dims=1, init=0) == [5 11 17]
879+
880+
# Let's test that `Broadcasted` actually hits the efficient
881+
# `mapreduce` method as intended. We are going to invoke `reduce`
882+
# with this *NON-ASSOCIATIVE* binary operator to see what
883+
# associativity is chosen by the implementation:
884+
paren = (x, y) -> "($x,$y)"
885+
# Next, we construct data `xs` such that `length(xs)` is greater
886+
# than the short-array cutoff of `_mapreduce`:
887+
alphabets = 'a':'z'
888+
blksize = Base.pairwise_blocksize(identity, paren) ÷ length(alphabets)
889+
xs = repeat(alphabets, 2 * blksize)
890+
@test length(xs) > blksize
891+
# So far we constructed the data `xs` and reducing function
892+
# `paren` such that `reduce` and `foldl` results are different.
893+
# That is to say, this `reduce` does not hit the fall-back `foldl`
894+
# branch:
895+
@test foldl(paren, xs) != reduce(paren, xs)
896+
897+
# Now let's try it with `Broadcasted`:
898+
bcraw = Broadcast.broadcasted(identity, xs)
899+
bc = Broadcast.instantiate(bcraw)
900+
# If `Broadcasted` has `IndexLinear` style, it should hit the
901+
# `reduce` branch:
902+
@test IndexStyle(bc) == IndexLinear()
903+
@test reduce(paren, bc) == reduce(paren, xs)
904+
# If `Broadcasted` does not have `IndexLinear` style, it should
905+
# hit the `foldl` branch:
906+
@test IndexStyle(bcraw) == IndexCartesian()
907+
@test reduce(paren, bcraw) == foldl(paren, xs)
908+
end
909+
854910
# treat Pair as scalar:
855911
@test replace.(split("The quick brown fox jumps over the lazy dog"), r"[aeiou]"i => "_") ==
856912
["Th_", "q__ck", "br_wn", "f_x", "j_mps", "_v_r", "th_", "l_zy", "d_g"]

0 commit comments

Comments
 (0)