diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 0805da91f..2ebccbd84 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -25,12 +25,12 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin return ys, partialsort_pullback end -function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...) +function frule((_, ẋs), ::typeof(sort), xs::AbstractArray; kw...) inds = sortperm(xs; kw...) return xs[inds], ẋs[inds] end -function rrule(::typeof(sort), xs::AbstractVector; kwargs...) +function rrule(::typeof(sort), xs::AbstractArray; kwargs...) inds = sortperm(xs; kwargs...) ys = xs[inds] diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 052045d1e..d06067bd2 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -7,6 +7,18 @@ # rev test_rrule(sort, a) test_rrule(sort, a; fkwargs=(;rev=true)) + + if VERSION ≥ v"1.9" + a = rand(5, 4) + for dims in (1, 2) + # fwd + test_frule(sort, a; fkwargs=(;dims)) + test_frule(sort, a; fkwargs=(;dims, rev=true)) + # rev + test_rrule(sort, a; fkwargs=(;dims)) + test_rrule(sort, a; fkwargs=(;dims, rev=true)) + end + end end @testset "partialsort" begin a = rand(10)