Skip to content

Make Broadcasted iterable and more indexable #26987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,15 @@ end
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)

Base.show(io::IO, bc::Broadcasted{Style}) where {Style} = print(io, Broadcasted, '{', Style, "}(", bc.f, ", ", bc.args, ')')
function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
print(io, Broadcasted)
# Only show the style parameter if we have a set of axes — representing an instantiated
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
# computation that is not relevant for dispatch, confusing, and just extra line noise.
bc.axes isa Tuple && print(io, '{', Style, '}')
print(io, '(', bc.f, ", ", bc.args, ')')
nothing
end

## Allocating the output container
"""
Expand Down Expand Up @@ -218,8 +226,6 @@ This should only be specialized for objects that do not define axes but want to
"""
broadcast_axes

### End of methods that users will typically have to specialize ###

@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
Expand All @@ -239,19 +245,39 @@ _not_nested(t::Tuple) = _not_nested(tail(t))
_not_nested(::NestedTuple) = false
_not_nested(::Tuple{}) = true

@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
_eachindex(t::Tuple{Any}) = t[1]
_eachindex(t::Tuple) = CartesianIndices(t)

Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N

Base.length(bc::Broadcasted) = prod(map(length, axes(bc)))
Base.size(bc::Broadcasted) = _size(axes(bc))
_size(::Tuple{Vararg{Base.OneTo}}) = map(length, axes(bc))

Base.start(bc::Broadcasted) = (iter = eachindex(bc); (iter, start(iter)))
Base.@propagate_inbounds function Base.next(bc::Broadcasted, s)
iter, state = s
i, newstate = next(iter, state)
return (bc[i], (iter, newstate))
end
Base.done(bc::Broadcasted, s) = done(s[1], s[2])

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown()

## Instantiation fills in the "missing" fields in Broadcasted.
instantiate(x) = x

"""
Broadcast.instantiate(bc::Broadcasted)

Construct the axes and indexing helpers for the lazy Broadcasted object `bc`.
Construct and check the axes for the lazy Broadcasted object `bc`.

Custom `BroadcastStyle`s may override this default in cases where it is fast and easy
to compute the resulting `axes` and indexing helpers on-demand, leaving those fields
of the `Broadcasted` object empty (populated with `nothing`). If they do so, however,
they must provide their own `Base.axes(::Broadcasted{Style})` and
`Base.getindex(::Broadcasted{Style}, I::Union{Int,CartesianIndex})` methods as appropriate.
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
of the `Broadcasted` object empty (populated with `nothing`).
"""
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
Expand Down Expand Up @@ -481,6 +507,7 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
# If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`.
@inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault))
@inline newindex(i::Int, keep::Tuple{Bool}, idefault) = ifelse(keep[1], i, idefault[1])
@inline newindex(i::Int, keep::Tuple{}, idefault) = CartesianIndex(())
@inline _newindex(I, keep, Idefault) =
(ifelse(keep[1], I[1], Idefault[1]), _newindex(tail(I), tail(keep), tail(Idefault))...)
@inline _newindex(I, keep::Tuple{}, Idefault) = () # truncate if keep is shorter than I
Expand All @@ -496,12 +523,14 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
(length(ind1)!=1, keep...), (first(ind1), Idefault...)
end

@inline function Base.getindex(bc::Broadcasted, I)
@inline function Base.getindex(bc::Broadcasted, I::Union{Int,CartesianIndex})
@boundscheck checkbounds(bc, I)
@inbounds _broadcast_getindex(bc, I)
end
Base.@propagate_inbounds Base.getindex(bc::Broadcasted, i1::Int, i2::Int, I::Int...) = bc[CartesianIndex((i1, i2, I...))]
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]

@inline Base.checkbounds(bc::Broadcasted, I) =
@inline Base.checkbounds(bc::Broadcasted, I::Union{Int,CartesianIndex}) =
Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,))


Expand Down Expand Up @@ -739,7 +768,7 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
# value to determine the starting output eltype; copyto_nonleaf!
# will widen `dest` as needed to accommodate later values.
bc′ = preprocess(nothing, bc)
iter = CartesianIndices(axes(bc′))
iter = eachindex(bc′)
state = start(iter)
if done(iter, state)
# if empty, take the ElType at face value
Expand Down Expand Up @@ -807,7 +836,7 @@ preprocess_args(dest, args::Tuple{}) = ()
end
end
bc′ = preprocess(dest, bc)
@simd for I in CartesianIndices(axes(bc′))
@simd for I in eachindex(bc′)
@inbounds dest[I] = bc′[I]
end
return dest
Expand All @@ -822,7 +851,7 @@ end
destc = dest.chunks
ind = cind = 1
bc′ = preprocess(dest, bc)
@simd for I in CartesianIndices(axes(bc′))
@simd for I in eachindex(bc′)
@inbounds tmp[ind] = bc′[I]
ind += 1
if ind > bitcache_size
Expand Down
23 changes: 23 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,26 @@ let f(args...) = *(args...)
@test f.(x..., y, z...) == broadcast(f, x..., y, z...) == 120
@test f.(x..., f.(x..., y, z...), y, z...) == broadcast(f, x..., broadcast(f, x..., y, z...), y, z...) == 120*120
end

# Broadcasted iterable/indexable APIs
let
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
@test eachindex(bc) === Base.OneTo(5)
@test length(bc) === 5
@test ndims(bc) === 1
@test ndims(typeof(bc)) === 1
@test bc[1] === bc[CartesianIndex((1,))] === 5.0
@test copy(bc) == [v for v in bc] == collect(bc)
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)

bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4)))
@test eachindex(bc) === CartesianIndices((Base.OneTo(5), Base.OneTo(4)))
@test length(bc) === 20
@test ndims(bc) === 2
@test ndims(typeof(bc)) === 2
@test bc[1,1] == bc[CartesianIndex((1,1))] === 5.0
@test copy(bc) == [v for v in bc] == collect(bc)
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
end