diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 6b21ed46..fa520eb9 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -123,6 +123,7 @@ include("abstractarray.jl") include("indexing.jl") include("broadcast.jl") include("mapreduce.jl") +include("sort.jl") include("arraymath.jl") include("linalg.jl") include("matrix_multiply.jl") diff --git a/src/sort.jl b/src/sort.jl new file mode 100644 index 00000000..01f08837 --- /dev/null +++ b/src/sort.jl @@ -0,0 +1,75 @@ +import Base.Order: Ordering, Forward, ReverseOrdering, ord +import Base.Sort: Algorithm, lt, sort + + +struct BitonicSortAlg <: Algorithm end + +const BitonicSort = BitonicSortAlg() + + +# BitonicSort has non-optimal asymptotic behaviour, so we define a cutoff +# length. This also prevents compilation time to skyrocket for larger vectors. +defalg(a::StaticVector) = + isimmutable(a) && length(a) <= 20 ? BitonicSort : QuickSort + +@inline function sort(a::StaticVector; + alg::Algorithm = defalg(a), + lt = isless, + by = identity, + rev::Union{Bool,Nothing} = nothing, + order::Ordering = Forward) + length(a) <= 1 && return a + ordr = ord(lt, by, rev, order) + return _sort(a, alg, ordr) +end + +@inline _sort(a::StaticVector, alg, order) = + similar_type(a)(sort!(Base.copymutable(a); alg=alg, order=order)) + +@inline _sort(a::StaticVector, alg::BitonicSortAlg, order) = + similar_type(a)(_sort(Tuple(a), alg, order)) + +# Implementation loosely following +# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm +@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N + function swap_expr(i, j, rev) + ai = Symbol('a', i) + aj = Symbol('a', j) + order = rev ? :revorder : :order + return :( ($ai, $aj) = lt($order, $ai, $aj) ? ($ai, $aj) : ($aj, $ai) ) + end + + function merge_exprs(idx, rev) + exprs = Expr[] + length(idx) == 1 && return exprs + + ci = 2^(ceil(Int, log2(length(idx))) - 1) + # TODO: generate simd code for these swaps + for i in first(idx):last(idx)-ci + push!(exprs, swap_expr(i, i+ci, rev)) + end + append!(exprs, merge_exprs(idx[1:ci], rev)) + append!(exprs, merge_exprs(idx[ci+1:end], rev)) + return exprs + end + + function sort_exprs(idx, rev=false) + exprs = Expr[] + length(idx) == 1 && return exprs + + append!(exprs, sort_exprs(idx[1:end÷2], !rev)) + append!(exprs, sort_exprs(idx[end÷2+1:end], rev)) + append!(exprs, merge_exprs(idx, rev)) + return exprs + end + + idx = 1:N + symlist = (Symbol('a', i) for i in idx) + return quote + @_inline_meta + revorder = Base.Order.ReverseOrdering(order) + ($(symlist...),) = a + ($(sort_exprs(idx)...);) + return ($(symlist...),) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 34b508fc..961c8a2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ include("abstractarray.jl") include("indexing.jl") include("initializers.jl") Random.seed!(42); include("mapreduce.jl") +Random.seed!(42); include("sort.jl") Random.seed!(42); include("accumulate.jl") Random.seed!(42); include("arraymath.jl") include("broadcast.jl") diff --git a/test/sort.jl b/test/sort.jl new file mode 100644 index 00000000..01af14da --- /dev/null +++ b/test/sort.jl @@ -0,0 +1,22 @@ +using StaticArrays, Test + +@testset "sort" begin + + @testset "basics" for T in (Int, Float64) + for N in (0, 1, 2, 3, 10, 20, 30) + vs = rand(SVector{N,T}) + vm = MVector{N,T}(vs) + vref = sort(Vector(vs)) + + @test @inferred(sort(vs)) isa SVector + @test @inferred(sort(vs, alg=QuickSort)) isa SVector + @test @inferred(sort(vm)) isa MVector + @test vref == sort(vs) + @test vref == sort(vm) + + # @allocated seems broken since 1.4 + #N <= 20 && @test 0 == @allocated sort(vs) + end + end + +end