Skip to content

Commit

Permalink
Allow N-dimensional arrays in sorting rules
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Aug 30, 2023
1 parent 79722bf commit a9e831b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin
return ys, partialsort_pullback
end

function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...)
function frule((_, ẋs), ::typeof(sort), xs::AbstractArray; kw...)
inds = sortperm(xs; kw...)
return xs[inds], ẋs[inds]
end

function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
function rrule(::typeof(sort), xs::AbstractArray; kwargs...)
inds = sortperm(xs; kwargs...)
ys = xs[inds]

Expand Down
9 changes: 9 additions & 0 deletions test/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
# rev
test_rrule(sort, a)
test_rrule(sort, a; fkwargs=(;rev=true))

a = rand(5, 4)
for dims in (1, 2)
# fwd
test_frule(sort, a; fkwargs=(;dims))
test_frule(sort, a; fkwargs=(;dims, rev=true))
# rev
test_rrule(sort, a; fkwargs=(;dims))
test_rrule(sort, a; fkwargs=(;dims, rev=true))
end
@testset "partialsort" begin
a = rand(10)
Expand Down

0 comments on commit a9e831b

Please sign in to comment.