Skip to content

Commit a9e831b

Browse files
committed
Allow N-dimensional arrays in sorting rules
1 parent 79722bf commit a9e831b

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

src/rulesets/Base/sort.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin
2525
return ys, partialsort_pullback
2626
end
2727

28-
function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...)
28+
function frule((_, ẋs), ::typeof(sort), xs::AbstractArray; kw...)
2929
inds = sortperm(xs; kw...)
3030
return xs[inds], ẋs[inds]
3131
end
3232

33-
function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
33+
function rrule(::typeof(sort), xs::AbstractArray; kwargs...)
3434
inds = sortperm(xs; kwargs...)
3535
ys = xs[inds]
3636

test/rulesets/Base/sort.jl

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
# rev
88
test_rrule(sort, a)
99
test_rrule(sort, a; fkwargs=(;rev=true))
10+
11+
a = rand(5, 4)
12+
for dims in (1, 2)
13+
# fwd
14+
test_frule(sort, a; fkwargs=(;dims))
15+
test_frule(sort, a; fkwargs=(;dims, rev=true))
16+
# rev
17+
test_rrule(sort, a; fkwargs=(;dims))
18+
test_rrule(sort, a; fkwargs=(;dims, rev=true))
1019
end
1120
@testset "partialsort" begin
1221
a = rand(10)

0 commit comments

Comments
 (0)