diff --git a/make/modules/java.base/Lib.gmk b/make/modules/java.base/Lib.gmk index cc25cb8d0ae54..57f98707cb7bc 100644 --- a/make/modules/java.base/Lib.gmk +++ b/make/modules/java.base/Lib.gmk @@ -227,10 +227,30 @@ ifeq ($(ENABLE_FALLBACK_LINKER), true) NAME := fallbackLinker, \ CFLAGS := $(CFLAGS_JDKLIB) $(LIBFFI_CFLAGS), \ LDFLAGS := $(LDFLAGS_JDKLIB) \ - $(call SET_SHARED_LIBRARY_ORIGIN), \ + $(call SET_SHARED_LIBRARY_ORIGIN), \ LIBS := $(LIBFFI_LIBS), \ LIBS_windows := $(LIBFFI_LIBS) ws2_32.lib, \ )) TARGETS += $(BUILD_LIBFALLBACKLINKER) endif + +################################################################################ + +ifeq ($(call isTargetOs, linux)+$(call isTargetCpu, x86_64)+$(INCLUDE_COMPILER2)+$(filter $(TOOLCHAIN_TYPE), gcc), true+true+true+gcc) + $(eval $(call SetupJdkLibrary, BUILD_LIB_SIMD_SORT, \ + NAME := simdsort, \ + TOOLCHAIN := TOOLCHAIN_LINK_CXX, \ + OPTIMIZATION := HIGH, \ + CFLAGS := $(CFLAGS_JDKLIB), \ + CXXFLAGS := $(CXXFLAGS_JDKLIB), \ + LDFLAGS := $(LDFLAGS_JDKLIB) \ + $(call SET_SHARED_LIBRARY_ORIGIN), \ + LIBS := $(LIBCXX), \ + LIBS_linux := -lc -lm -ldl, \ + )) + + TARGETS += $(BUILD_LIB_SIMD_SORT) +endif + +################################################################################ diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp index c6178836df2a8..79ebef8b58113 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp @@ -4172,6 +4172,26 @@ void StubGenerator::generate_compiler_stubs() { = CAST_FROM_FN_PTR(address, SharedRuntime::montgomery_square); } + // Load x86_64_sort library on supported hardware to enable avx512 sort and partition intrinsics + if (UseAVX > 2 && VM_Version::supports_avx512dq()) { + void *libsimdsort = nullptr; + char ebuf_[1024]; + char dll_name_simd_sort[JVM_MAXPATHLEN]; + if (os::dll_locate_lib(dll_name_simd_sort, sizeof(dll_name_simd_sort), Arguments::get_dll_dir(), "simdsort")) { + libsimdsort = os::dll_load(dll_name_simd_sort, ebuf_, sizeof ebuf_); + } + // Get addresses for avx512 sort and partition routines + if (libsimdsort != nullptr) { + log_info(library)("Loaded library %s, handle " INTPTR_FORMAT, JNI_LIB_PREFIX "simdsort" JNI_LIB_SUFFIX, p2i(libsimdsort)); + + snprintf(ebuf_, sizeof(ebuf_), "avx512_sort"); + StubRoutines::_array_sort = (address)os::dll_lookup(libsimdsort, ebuf_); + + snprintf(ebuf_, sizeof(ebuf_), "avx512_partition"); + StubRoutines::_array_partition = (address)os::dll_lookup(libsimdsort, ebuf_); + } + } + // Get svml stub routine addresses void *libjsvml = nullptr; char ebuf[1024]; diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp b/src/hotspot/share/classfile/vmIntrinsics.hpp index 1f2d601a02bcd..66b8a43640728 100644 --- a/src/hotspot/share/classfile/vmIntrinsics.hpp +++ b/src/hotspot/share/classfile/vmIntrinsics.hpp @@ -341,6 +341,14 @@ class methodHandle; do_name( copyOf_name, "copyOf") \ do_signature(copyOf_signature, "([Ljava/lang/Object;ILjava/lang/Class;)[Ljava/lang/Object;") \ \ + do_intrinsic(_arraySort, java_util_DualPivotQuicksort, arraySort_name, arraySort_signature, F_S) \ + do_name( arraySort_name, "sort") \ + do_signature(arraySort_signature, "(Ljava/lang/Class;Ljava/lang/Object;JIILjava/util/DualPivotQuicksort$SortOperation;)V") \ + \ + do_intrinsic(_arrayPartition, java_util_DualPivotQuicksort, arrayPartition_name, arrayPartition_signature, F_S) \ + do_name( arrayPartition_name, "partition") \ + do_signature(arrayPartition_signature, 
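/* signature key: Class = element type, Object = the array being partitioned,
   J = a long base offset, the four ints = fromIndex, toIndex and the two
   pivot indices; the int[] result carries back the two partition boundaries
   computed by the stub (see inline_array_partition in library_call.cpp) */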
"(Ljava/lang/Class;Ljava/lang/Object;JIIIILjava/util/DualPivotQuicksort$PartitionOperation;)[I") \ + \ do_intrinsic(_copyOfRange, java_util_Arrays, copyOfRange_name, copyOfRange_signature, F_S) \ do_name( copyOfRange_name, "copyOfRange") \ do_signature(copyOfRange_signature, "([Ljava/lang/Object;IILjava/lang/Class;)[Ljava/lang/Object;") \ diff --git a/src/hotspot/share/classfile/vmSymbols.hpp b/src/hotspot/share/classfile/vmSymbols.hpp index 1e5bbb60ca75f..b450c1250539c 100644 --- a/src/hotspot/share/classfile/vmSymbols.hpp +++ b/src/hotspot/share/classfile/vmSymbols.hpp @@ -145,6 +145,7 @@ class SerializeClosure; template(java_util_Vector, "java/util/Vector") \ template(java_util_AbstractList, "java/util/AbstractList") \ template(java_util_Hashtable, "java/util/Hashtable") \ + template(java_util_DualPivotQuicksort, "java/util/DualPivotQuicksort") \ template(java_lang_Compiler, "java/lang/Compiler") \ template(jdk_internal_misc_Signal, "jdk/internal/misc/Signal") \ template(jdk_internal_util_Preconditions, "jdk/internal/util/Preconditions") \ diff --git a/src/hotspot/share/gc/shenandoah/c2/shenandoahSupport.cpp b/src/hotspot/share/gc/shenandoah/c2/shenandoahSupport.cpp index 21e6a068de1b7..8c7ee3a3aedc1 100644 --- a/src/hotspot/share/gc/shenandoah/c2/shenandoahSupport.cpp +++ b/src/hotspot/share/gc/shenandoah/c2/shenandoahSupport.cpp @@ -387,6 +387,12 @@ void ShenandoahBarrierC2Support::verify(RootNode* root) { verify_type t; } args[6]; } calls[] = { + "array_partition_stub", + { { TypeFunc::Parms, ShenandoahStore }, { TypeFunc::Parms+4, ShenandoahStore }, { -1, ShenandoahNone }, + { -1, ShenandoahNone }, { -1, ShenandoahNone }, { -1, ShenandoahNone } }, + "arraysort_stub", + { { TypeFunc::Parms, ShenandoahStore }, { -1, ShenandoahNone }, { -1, ShenandoahNone }, + { -1, ShenandoahNone}, { -1, ShenandoahNone}, { -1, ShenandoahNone} }, "aescrypt_encryptBlock", { { TypeFunc::Parms, ShenandoahLoad }, { TypeFunc::Parms+1, ShenandoahStore }, { TypeFunc::Parms+2, ShenandoahLoad }, { -1, ShenandoahNone}, { -1, ShenandoahNone}, { -1, ShenandoahNone} }, diff --git a/src/hotspot/share/jvmci/vmStructs_jvmci.cpp b/src/hotspot/share/jvmci/vmStructs_jvmci.cpp index 4a6440d1699ca..a6bb68ce13ce5 100644 --- a/src/hotspot/share/jvmci/vmStructs_jvmci.cpp +++ b/src/hotspot/share/jvmci/vmStructs_jvmci.cpp @@ -331,6 +331,8 @@ static_field(StubRoutines, _checkcast_arraycopy_uninit, address) \ static_field(StubRoutines, _unsafe_arraycopy, address) \ static_field(StubRoutines, _generic_arraycopy, address) \ + static_field(StubRoutines, _array_sort, address) \ + static_field(StubRoutines, _array_partition, address) \ \ static_field(StubRoutines, _aescrypt_encryptBlock, address) \ static_field(StubRoutines, _aescrypt_decryptBlock, address) \ diff --git a/src/hotspot/share/opto/c2compiler.cpp b/src/hotspot/share/opto/c2compiler.cpp index a9f09cfb80b13..e0ec2e34a0a69 100644 --- a/src/hotspot/share/opto/c2compiler.cpp +++ b/src/hotspot/share/opto/c2compiler.cpp @@ -614,6 +614,8 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) { case vmIntrinsics::_min_strict: case vmIntrinsics::_max_strict: case vmIntrinsics::_arraycopy: + case vmIntrinsics::_arraySort: + case vmIntrinsics::_arrayPartition: case vmIntrinsics::_indexOfL: case vmIntrinsics::_indexOfU: case vmIntrinsics::_indexOfUL: diff --git a/src/hotspot/share/opto/escape.cpp b/src/hotspot/share/opto/escape.cpp index 0570f043dbff9..c0da4b4a120ac 100644 --- a/src/hotspot/share/opto/escape.cpp +++ b/src/hotspot/share/opto/escape.cpp @@ -1575,6 +1575,8 @@ 
void ConnectionGraph::process_call_arguments(CallNode *call) { strcmp(call->as_CallLeaf()->_name, "bigIntegerRightShiftWorker") == 0 || strcmp(call->as_CallLeaf()->_name, "bigIntegerLeftShiftWorker") == 0 || strcmp(call->as_CallLeaf()->_name, "vectorizedMismatch") == 0 || + strcmp(call->as_CallLeaf()->_name, "arraysort_stub") == 0 || + strcmp(call->as_CallLeaf()->_name, "array_partition_stub") == 0 || strcmp(call->as_CallLeaf()->_name, "get_class_id_intrinsic") == 0) ))) { call->dump(); diff --git a/src/hotspot/share/opto/library_call.cpp b/src/hotspot/share/opto/library_call.cpp index 269ac1202fb0d..4a9d7fb161667 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -293,6 +293,9 @@ bool LibraryCallKit::try_to_inline(int predicate) { case vmIntrinsics::_arraycopy: return inline_arraycopy(); + case vmIntrinsics::_arraySort: return inline_array_sort(); + case vmIntrinsics::_arrayPartition: return inline_array_partition(); + case vmIntrinsics::_compareToL: return inline_string_compareTo(StrIntrinsicNode::LL); case vmIntrinsics::_compareToU: return inline_string_compareTo(StrIntrinsicNode::UU); case vmIntrinsics::_compareToLU: return inline_string_compareTo(StrIntrinsicNode::LU); @@ -5361,6 +5364,101 @@ void LibraryCallKit::create_new_uncommon_trap(CallStaticJavaNode* uncommon_trap_ uncommon_trap_call->set_req(0, top()); // not used anymore, kill it } +//------------------------------inline_array_partition----------------------- +bool LibraryCallKit::inline_array_partition() { + + const char *stubName = "array_partition_stub"; + + Node* elementType = null_check(argument(0)); + Node* obj = argument(1); + Node* offset = argument(2); + Node* fromIndex = argument(4); + Node* toIndex = argument(5); + Node* indexPivot1 = argument(6); + Node* indexPivot2 = argument(7); + + const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr(); + ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type(); + BasicType bt = elem_type->basic_type(); + address stubAddr = nullptr; + stubAddr = StubRoutines::select_array_partition_function(); + // stub not loaded + if (stubAddr == nullptr) { + return false; + } + // get the address of the array + const TypeAryPtr* obj_t = _gvn.type(obj)->isa_aryptr(); + if (obj_t == nullptr || obj_t->elem() == Type::BOTTOM ) { + return false; // failed input validation + } + Node* obj_adr = make_unsafe_address(obj, offset); + + // create the pivotIndices array of type int and size = 2 + Node* size = intcon(2); + Node* klass_node = makecon(TypeKlassPtr::make(ciTypeArrayKlass::make(T_INT))); + Node* pivotIndices = new_array(klass_node, size, 0); // no arguments to push + AllocateArrayNode* alloc = tightly_coupled_allocation(pivotIndices); + guarantee(alloc != nullptr, "created above"); + Node* pivotIndices_adr = basic_plus_adr(pivotIndices, arrayOopDesc::base_offset_in_bytes(T_INT)); + + // pass the basic type enum to the stub + Node* elemType = intcon(bt); + + // Call the stub + make_runtime_call(RC_LEAF|RC_NO_FP, OptoRuntime::array_partition_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + obj_adr, elemType, fromIndex, toIndex, pivotIndices_adr, + indexPivot1, indexPivot2); + + if (!stopped()) { + set_result(pivotIndices); + } + + return true; +} + + +//------------------------------inline_array_sort----------------------- +bool LibraryCallKit::inline_array_sort() { + + const char *stubName; + stubName = "arraysort_stub"; + + Node* elementType = null_check(argument(0)); + Node* obj = argument(1); + Node* 
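/* note: the J (long) offset argument occupies two JVM argument slots, so the
   operands after it are fetched from argument(4) onward rather than
   argument(3) -- the same layout as in inline_array_partition above */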
offset = argument(2); + Node* fromIndex = argument(4); + Node* toIndex = argument(5); + + const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr(); + ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type(); + BasicType bt = elem_type->basic_type(); + address stubAddr = nullptr; + stubAddr = StubRoutines::select_arraysort_function(); + //stub not loaded + if (stubAddr == nullptr) { + return false; + } + + // get address of the array + const TypeAryPtr* obj_t = _gvn.type(obj)->isa_aryptr(); + if (obj_t == nullptr || obj_t->elem() == Type::BOTTOM ) { + return false; // failed input validation + } + Node* obj_adr = make_unsafe_address(obj, offset); + + // pass the basic type enum to the stub + Node* elemType = intcon(bt); + + // Call the stub. + make_runtime_call(RC_LEAF|RC_NO_FP, OptoRuntime::array_sort_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + obj_adr, elemType, fromIndex, toIndex); + + return true; +} + + //------------------------------inline_arraycopy----------------------- // public static native void java.lang.System.arraycopy(Object src, int srcPos, // Object dest, int destPos, diff --git a/src/hotspot/share/opto/library_call.hpp b/src/hotspot/share/opto/library_call.hpp index f714625a4df47..55d1dc78f1fd5 100644 --- a/src/hotspot/share/opto/library_call.hpp +++ b/src/hotspot/share/opto/library_call.hpp @@ -277,7 +277,8 @@ class LibraryCallKit : public GraphKit { JVMState* arraycopy_restore_alloc_state(AllocateArrayNode* alloc, int& saved_reexecute_sp); void arraycopy_move_allocation_here(AllocateArrayNode* alloc, Node* dest, JVMState* saved_jvms_before_guards, int saved_reexecute_sp, uint new_idx); - + bool inline_array_sort(); + bool inline_array_partition(); typedef enum { LS_get_add, LS_get_set, LS_cmp_swap, LS_cmp_swap_weak, LS_cmp_exchange } LoadStoreKind; bool inline_unsafe_load_store(BasicType type, LoadStoreKind kind, AccessKind access_kind); bool inline_unsafe_fence(vmIntrinsics::ID id); diff --git a/src/hotspot/share/opto/runtime.cpp b/src/hotspot/share/opto/runtime.cpp index bb79da3262245..8f3367ce0003c 100644 --- a/src/hotspot/share/opto/runtime.cpp +++ b/src/hotspot/share/opto/runtime.cpp @@ -857,6 +857,49 @@ const TypeFunc* OptoRuntime::array_fill_Type() { return TypeFunc::make(domain, range); } +const TypeFunc* OptoRuntime::array_partition_Type() { + // create input type (domain) + int num_args = 7; + int argcnt = num_args; + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // array + fields[argp++] = TypeInt::INT; // element type + fields[argp++] = TypeInt::INT; // low + fields[argp++] = TypeInt::INT; // end + fields[argp++] = TypePtr::NOTNULL; // pivot_indices (int array) + fields[argp++] = TypeInt::INT; // indexPivot1 + fields[argp++] = TypeInt::INT; // indexPivot2 + assert(argp == TypeFunc::Parms+argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields); + + // no result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms+0] = nullptr; // void + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields); + return TypeFunc::make(domain, range); +} + +const TypeFunc* OptoRuntime::array_sort_Type() { + // create input type (domain) + int num_args = 4; + int argcnt = num_args; + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // array + fields[argp++] = TypeInt::INT; // element type + fields[argp++] = TypeInt::INT; // 
fromIndex + fields[argp++] = TypeInt::INT; // toIndex + assert(argp == TypeFunc::Parms+argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields); + + // no result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms+0] = nullptr; // void + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields); + return TypeFunc::make(domain, range); +} + // for aescrypt encrypt/decrypt operations, just three pointers returning void (length is constant) const TypeFunc* OptoRuntime::aescrypt_block_Type() { // create input type (domain) diff --git a/src/hotspot/share/opto/runtime.hpp b/src/hotspot/share/opto/runtime.hpp index cd13c14148d71..b85542423e848 100644 --- a/src/hotspot/share/opto/runtime.hpp +++ b/src/hotspot/share/opto/runtime.hpp @@ -268,6 +268,8 @@ class OptoRuntime : public AllStatic { static const TypeFunc* array_fill_Type(); + static const TypeFunc* array_sort_Type(); + static const TypeFunc* array_partition_Type(); static const TypeFunc* aescrypt_block_Type(); static const TypeFunc* cipherBlockChaining_aescrypt_Type(); static const TypeFunc* electronicCodeBook_aescrypt_Type(); diff --git a/src/hotspot/share/runtime/stubRoutines.cpp b/src/hotspot/share/runtime/stubRoutines.cpp index c2c57c8f12374..bea2a934bc603 100644 --- a/src/hotspot/share/runtime/stubRoutines.cpp +++ b/src/hotspot/share/runtime/stubRoutines.cpp @@ -176,6 +176,9 @@ address StubRoutines::_hf2f = nullptr; address StubRoutines::_vector_f_math[VectorSupport::NUM_VEC_SIZES][VectorSupport::NUM_SVML_OP] = {{nullptr}, {nullptr}}; address StubRoutines::_vector_d_math[VectorSupport::NUM_VEC_SIZES][VectorSupport::NUM_SVML_OP] = {{nullptr}, {nullptr}}; +address StubRoutines::_array_sort = nullptr; +address StubRoutines::_array_partition = nullptr; + address StubRoutines::_cont_thaw = nullptr; address StubRoutines::_cont_returnBarrier = nullptr; address StubRoutines::_cont_returnBarrierExc = nullptr; diff --git a/src/hotspot/share/runtime/stubRoutines.hpp b/src/hotspot/share/runtime/stubRoutines.hpp index 2269cee2a53b1..d62c3913f2533 100644 --- a/src/hotspot/share/runtime/stubRoutines.hpp +++ b/src/hotspot/share/runtime/stubRoutines.hpp @@ -153,6 +153,8 @@ class StubRoutines: AllStatic { static BufferBlob* _compiler_stubs_code; // code buffer for C2 intrinsics static BufferBlob* _final_stubs_code; // code buffer for all other routines + static address _array_sort; + static address _array_partition; // Leaf routines which implement arraycopy and their addresses // arraycopy operands aligned on element type boundary static address _jbyte_arraycopy; @@ -375,6 +377,8 @@ class StubRoutines: AllStatic { static UnsafeArrayCopyStub UnsafeArrayCopy_stub() { return CAST_TO_FN_PTR(UnsafeArrayCopyStub, _unsafe_arraycopy); } static address generic_arraycopy() { return _generic_arraycopy; } + static address select_arraysort_function() { return _array_sort; } + static address select_array_partition_function() { return _array_partition; } static address jbyte_fill() { return _jbyte_fill; } static address jshort_fill() { return _jshort_fill; } diff --git a/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp b/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp new file mode 100644 index 0000000000000..4fbe9b97450c6 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp @@ -0,0 +1,441 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. 
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef AVX512_QSORT_32BIT +#define AVX512_QSORT_32BIT + +#include "avx512-common-qsort.h" + +/* + * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1 +#define NETWORK_32BIT_2 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_4 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 + +template <> +struct zmm_vector { + using type_t = int32_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; } + static type_t type_min() { return X86_SIMD_SORT_MIN_INT32; } + static zmm_t zmm_max() { return _mm512_set1_epi32(type_max()); } + + static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); } + static opmask_t ge(zmm_t x, zmm_t y) { + return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t gt(zmm_t x, zmm_t y) { + return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_GT); + } + template + static ymm_t i64gather(__m512i index, void const *base) { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) { + zmm_t z1 = _mm512_castsi256_si512(y1); + return _mm512_inserti32x8(z1, y2, 1); + } + static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + return _mm512_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi32(x, y); } + static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi32(x, y); } + static zmm_t permutexvar(__m512i idx, 
zmm_t zmm) { + return _mm512_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi32(v); } + static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi32(v); } + static zmm_t set1(type_t v) { return _mm512_set1_epi32(v); } + template + static zmm_t shuffle(zmm_t zmm) { + return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) { + return _mm512_storeu_si512(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_epi32(x, y); } + static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_epi32(x, y); } +}; +template <> +struct zmm_vector { + using type_t = float; + using zmm_t = __m512; + using ymm_t = __m256; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() { return X86_SIMD_SORT_INFINITYF; } + static type_t type_min() { return -X86_SIMD_SORT_INFINITYF; } + static zmm_t zmm_max() { return _mm512_set1_ps(type_max()); } + + static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); } + static opmask_t ge(zmm_t x, zmm_t y) { + return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + } + static opmask_t gt(zmm_t x, zmm_t y) { + return _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + } + template + static ymm_t i64gather(__m512i index, void const *base) { + return _mm512_i64gather_ps(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) { + zmm_t z1 = _mm512_castsi512_ps( + _mm512_castsi256_si512(_mm256_castps_si256(y1))); + return _mm512_insertf32x8(z1, y2, 1); + } + static zmm_t loadu(void const *mem) { return _mm512_loadu_ps(mem); } + static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_ps(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_compressstoreu_ps(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_ps(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + return _mm512_mask_mov_ps(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_storeu_ps(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_ps(x, y); } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) { + return _mm512_permutexvar_ps(idx, zmm); + } + static type_t reducemax(zmm_t v) { return _mm512_reduce_max_ps(v); } + static type_t reducemin(zmm_t v) { return _mm512_reduce_min_ps(v); } + static zmm_t set1(type_t v) { return _mm512_set1_ps(v); } + template + static zmm_t shuffle(zmm_t zmm) { + return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) { return _mm512_storeu_ps(mem, x); } + + static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_ps(x, y); } + static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_ps(x, y); } +}; + +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) { + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAAAA); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xCCCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAAAA); + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_3), zmm), + 0xF0F0); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xCCCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAAAA); + zmm = cmp_merge( + zmm, 
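/* cmp_merge, defined in the included avx512-common-qsort.h, is one
   comparator stage: lanes outside the mask keep the element-wise min of the
   two inputs and lanes inside the mask receive the max; the second input is
   always a lane permutation of zmm itself */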
vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm), + 0xFF00); + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm), + 0xF0F0); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xCCCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAAAA); + return zmm; +} + +// Assumes zmm is bitonic and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_32bit(zmm_t zmm) { + // 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm), + 0xFF00); + // 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc .. + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm), + 0xF0F0); + // 3) half_cleaner[4] + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xCCCC); + // 4) half_cleaner[1] + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAAAA); + return zmm; +} + +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_32bit(zmm_t *zmm1, + zmm_t *zmm2) { + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + *zmm2 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), *zmm2); + zmm_t zmm3 = vtype::min(*zmm1, *zmm2); + zmm_t zmm4 = vtype::max(*zmm1, *zmm2); + // 2) Recursive half cleaner for each + *zmm1 = bitonic_merge_zmm_32bit(zmm3); + *zmm2 = bitonic_merge_zmm_32bit(zmm4); +} + +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_32bit(zmm_t *zmm) { + zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[2]); + zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[3]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); + zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[1], zmm2r)); + zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[0], zmm3r)); + zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); + zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); + zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); + zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); + zmm[0] = bitonic_merge_zmm_32bit(zmm0); + zmm[1] = bitonic_merge_zmm_32bit(zmm1); + zmm[2] = bitonic_merge_zmm_32bit(zmm2); + zmm[3] = bitonic_merge_zmm_32bit(zmm3); +} + +template +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_32bit(zmm_t *zmm) { + zmm_t zmm4r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[4]); + zmm_t zmm5r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[5]); + zmm_t zmm6r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[6]); + zmm_t zmm7r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[7]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); + zmm_t zmm_t5 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[3], zmm4r)); + zmm_t zmm_t6 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[2], zmm5r)); + zmm_t zmm_t7 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[1], zmm6r)); + zmm_t zmm_t8 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), + vtype::max(zmm[0], zmm7r)); + COEX(zmm_t1, zmm_t3); + COEX(zmm_t2, zmm_t4); + COEX(zmm_t5, zmm_t7); + COEX(zmm_t6, zmm_t8); + COEX(zmm_t1, zmm_t2); + 
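// COEX (also from avx512-common-qsort.h) is a compare-exchange: it leaves
// min(a, b) in its first argument and max(a, b) in its second, so each run
// of COEX calls here forms one column of comparators of the merging network
// ahead of the per-register bitonic merges below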
COEX(zmm_t3, zmm_t4); + COEX(zmm_t5, zmm_t6); + COEX(zmm_t7, zmm_t8); + zmm[0] = bitonic_merge_zmm_32bit(zmm_t1); + zmm[1] = bitonic_merge_zmm_32bit(zmm_t2); + zmm[2] = bitonic_merge_zmm_32bit(zmm_t3); + zmm[3] = bitonic_merge_zmm_32bit(zmm_t4); + zmm[4] = bitonic_merge_zmm_32bit(zmm_t5); + zmm[5] = bitonic_merge_zmm_32bit(zmm_t6); + zmm[6] = bitonic_merge_zmm_32bit(zmm_t7); + zmm[7] = bitonic_merge_zmm_32bit(zmm_t8); +} + +template +X86_SIMD_SORT_INLINE void sort_16_32bit(type_t *arr, int32_t N) { + typename vtype::opmask_t load_mask = (0x0001 << N) - 0x0001; + typename vtype::zmm_t zmm = + vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); + vtype::mask_storeu(arr, load_mask, sort_zmm_32bit(zmm)); +} + +template +X86_SIMD_SORT_INLINE void sort_32_32bit(type_t *arr, int32_t N) { + if (N <= 16) { + sort_16_32bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + zmm_t zmm1 = vtype::loadu(arr); + typename vtype::opmask_t load_mask = (0x0001 << (N - 16)) - 0x0001; + zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 16); + zmm1 = sort_zmm_32bit(zmm1); + zmm2 = sort_zmm_32bit(zmm2); + bitonic_merge_two_zmm_32bit(&zmm1, &zmm2); + vtype::storeu(arr, zmm1); + vtype::mask_storeu(arr + 16, load_mask, zmm2); +} + +template +X86_SIMD_SORT_INLINE void sort_64_32bit(type_t *arr, int32_t N) { + if (N <= 32) { + sort_32_32bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[4]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 16); + opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + load_mask1 &= combined_mask & 0xFFFF; + load_mask2 &= (combined_mask >> 16) & 0xFFFF; + zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); + zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 48); + zmm[0] = sort_zmm_32bit(zmm[0]); + zmm[1] = sort_zmm_32bit(zmm[1]); + zmm[2] = sort_zmm_32bit(zmm[2]); + zmm[3] = sort_zmm_32bit(zmm[3]); + bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); + bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); + bitonic_merge_four_zmm_32bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 16, zmm[1]); + vtype::mask_storeu(arr + 32, load_mask1, zmm[2]); + vtype::mask_storeu(arr + 48, load_mask2, zmm[3]); +} + +template +X86_SIMD_SORT_INLINE void sort_128_32bit(type_t *arr, int32_t N) { + if (N <= 64) { + sort_64_32bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[8]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 16); + zmm[2] = vtype::loadu(arr + 32); + zmm[3] = vtype::loadu(arr + 48); + zmm[0] = sort_zmm_32bit(zmm[0]); + zmm[1] = sort_zmm_32bit(zmm[1]); + zmm[2] = sort_zmm_32bit(zmm[2]); + zmm[3] = sort_zmm_32bit(zmm[3]); + opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; + opmask_t load_mask3 = 0xFFFF, load_mask4 = 0xFFFF; + if (N != 128) { + uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; + load_mask1 &= combined_mask & 0xFFFF; + load_mask2 &= (combined_mask >> 16) & 0xFFFF; + load_mask3 &= (combined_mask >> 32) & 0xFFFF; + load_mask4 &= (combined_mask >> 48) & 0xFFFF; + } + zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); + zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 80); + zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 96); + zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 112); + zmm[4] = sort_zmm_32bit(zmm[4]); + zmm[5] = 
sort_zmm_32bit(zmm[5]); + zmm[6] = sort_zmm_32bit(zmm[6]); + zmm[7] = sort_zmm_32bit(zmm[7]); + bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); + bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); + bitonic_merge_two_zmm_32bit(&zmm[4], &zmm[5]); + bitonic_merge_two_zmm_32bit(&zmm[6], &zmm[7]); + bitonic_merge_four_zmm_32bit(zmm); + bitonic_merge_four_zmm_32bit(zmm + 4); + bitonic_merge_eight_zmm_32bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 16, zmm[1]); + vtype::storeu(arr + 32, zmm[2]); + vtype::storeu(arr + 48, zmm[3]); + vtype::mask_storeu(arr + 64, load_mask1, zmm[4]); + vtype::mask_storeu(arr + 80, load_mask2, zmm[5]); + vtype::mask_storeu(arr + 96, load_mask3, zmm[6]); + vtype::mask_storeu(arr + 112, load_mask4, zmm[7]); +} + + +template +static void qsort_32bit_(type_t *arr, int64_t left, int64_t right, + int64_t max_iters) { + /* + * Resort to std::sort if quicksort isn't making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays of <= 128 elements + */ + if (right + 1 - left <= 128) { + sort_128_32bit(arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_scalar(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest, false); + if (pivot != smallest) + qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + qsort_32bit_(arr, pivot_index, right, max_iters - 1); +} + +template <> +void inline avx512_qsort(int32_t *arr, int64_t fromIndex, int64_t toIndex) { + int64_t arrsize = toIndex - fromIndex; + if (arrsize > 1) { + qsort_32bit_, int32_t>(arr, fromIndex, toIndex - 1, + 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void inline avx512_qsort(float *arr, int64_t fromIndex, int64_t toIndex) { + int64_t arrsize = toIndex - fromIndex; + if (arrsize > 1) { + qsort_32bit_, float>(arr, fromIndex, toIndex - 1, + 2 * (int64_t)log2(arrsize)); + } +} + +#endif // AVX512_QSORT_32BIT diff --git a/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h b/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h new file mode 100644 index 0000000000000..9993cd22e6377 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef AVX512_64BIT_COMMON +#define AVX512_64BIT_COMMON +#include "avx512-common-qsort.h" + +/* + * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + +template <> +struct zmm_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using zmmi_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() { return X86_SIMD_SORT_MAX_INT64; } + static type_t type_min() { return X86_SIMD_SORT_MIN_INT64; } + static zmm_t zmm_max() { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? + + static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) { + return _kxor_mask8(x, y); + } + static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } + static opmask_t le(zmm_t x, zmm_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE); + } + static opmask_t ge(zmm_t x, zmm_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t gt(zmm_t x, zmm_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_GT); + } + static opmask_t eq(zmm_t x, zmm_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, + void const *base) { + return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); + } + template + static zmm_t i64gather(__m512i index, void const *base) { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } + static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi64(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm512_maskz_loadu_epi64(mask, mem); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi64(x, y); } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi64(v); } + static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi64(v); } + static zmm_t set1(type_t v) { return _mm512_set1_epi64(v); } + template + static zmm_t shuffle(zmm_t zmm) { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) { _mm512_storeu_si512(mem, x); } +}; +template <> +struct zmm_vector { + using type_t = double; + using zmm_t = __m512d; + using zmmi_t = __m512i; + 
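// double counterpart of the int64_t traits above; the comparisons use the
// ordered, non-signaling predicates (_CMP_GE_OQ and friends), so lanes
// holding NaN never satisfy ge/gt/eq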
using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() { return X86_SIMD_SORT_INFINITY; } + static type_t type_min() { return -X86_SIMD_SORT_INFINITY; } + static zmm_t zmm_max() { return _mm512_set1_pd(type_max()); } + + static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static zmm_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm512_maskz_loadu_pd(mask, mem); + } + static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } + static opmask_t ge(zmm_t x, zmm_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t gt(zmm_t x, zmm_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_GT_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + template + static opmask_t fpclass(zmm_t x) { + return _mm512_fpclass_pd_mask(x, type); + } + template + static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, + void const *base) { + return _mm512_mask_i64gather_pd(src, mask, index, base, scale); + } + template + static zmm_t i64gather(__m512i index, void const *base) { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) { return _mm512_loadu_pd(mem); } + static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_pd(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_pd(x, y); } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) { return _mm512_reduce_max_pd(v); } + static type_t reducemin(zmm_t v) { return _mm512_reduce_min_pd(v); } + static zmm_t set1(type_t v) { return _mm512_set1_pd(v); } + template + static zmm_t shuffle(zmm_t zmm) { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) { _mm512_storeu_pd(mem, x); } +}; + +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) { + const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge( + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); + zmm = cmp_merge( + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + return zmm; +} + + +#endif diff --git a/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp b/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp new file mode 100644 index 0000000000000..e28ebe19695de --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp @@ -0,0 +1,772 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. 
All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef AVX512_QSORT_64BIT +#define AVX512_QSORT_64BIT + +#include "avx512-64bit-common.h" + +// Assumes zmm is bitonic and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) { + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), 0xF0); + // 2) half_cleaner[4] + zmm = cmp_merge( + zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), 0xCC); + // 3) half_cleaner[1] + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + return zmm; +} +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, + zmm_t &zmm2) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + zmm2 = vtype::permutexvar(rev_index, zmm2); + zmm_t zmm3 = vtype::min(zmm1, zmm2); + zmm_t zmm4 = vtype::max(zmm1, zmm2); + // 2) Recursive half cleaner for each + zmm1 = bitonic_merge_zmm_64bit(zmm3); + zmm2 = bitonic_merge_zmm_64bit(zmm4); +} +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network + zmm_t zmm2r = vtype::permutexvar(rev_index, zmm[2]); + zmm_t zmm3r = vtype::permutexvar(rev_index, zmm[3]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); + // 2) Recursive half clearer: 16 + zmm_t zmm_t3 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm2r)); + zmm_t zmm_t4 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm3r)); + zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); + zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); + zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); + zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); + zmm[0] = bitonic_merge_zmm_64bit(zmm0); + zmm[1] = bitonic_merge_zmm_64bit(zmm1); + zmm[2] = bitonic_merge_zmm_64bit(zmm2); + zmm[3] = bitonic_merge_zmm_64bit(zmm3); +} +template +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t zmm4r = vtype::permutexvar(rev_index, zmm[4]); + zmm_t zmm5r = vtype::permutexvar(rev_index, zmm[5]); + zmm_t zmm6r = 
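/* the r-suffixed registers are lane-reversed copies (rev_index reverses the
   eight 64-bit lanes), so the min/max pairs below implement the first
   comparator stage of the merge */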
vtype::permutexvar(rev_index, zmm[6]); + zmm_t zmm7r = vtype::permutexvar(rev_index, zmm[7]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); + zmm_t zmm_t5 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm4r)); + zmm_t zmm_t6 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm5r)); + zmm_t zmm_t7 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm6r)); + zmm_t zmm_t8 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm7r)); + COEX(zmm_t1, zmm_t3); + COEX(zmm_t2, zmm_t4); + COEX(zmm_t5, zmm_t7); + COEX(zmm_t6, zmm_t8); + COEX(zmm_t1, zmm_t2); + COEX(zmm_t3, zmm_t4); + COEX(zmm_t5, zmm_t6); + COEX(zmm_t7, zmm_t8); + zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); + zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); + zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); + zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); + zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); + zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); + zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); + zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); +} +template +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t zmm8r = vtype::permutexvar(rev_index, zmm[8]); + zmm_t zmm9r = vtype::permutexvar(rev_index, zmm[9]); + zmm_t zmm10r = vtype::permutexvar(rev_index, zmm[10]); + zmm_t zmm11r = vtype::permutexvar(rev_index, zmm[11]); + zmm_t zmm12r = vtype::permutexvar(rev_index, zmm[12]); + zmm_t zmm13r = vtype::permutexvar(rev_index, zmm[13]); + zmm_t zmm14r = vtype::permutexvar(rev_index, zmm[14]); + zmm_t zmm15r = vtype::permutexvar(rev_index, zmm[15]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm15r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm14r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm13r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm12r); + zmm_t zmm_t5 = vtype::min(zmm[4], zmm11r); + zmm_t zmm_t6 = vtype::min(zmm[5], zmm10r); + zmm_t zmm_t7 = vtype::min(zmm[6], zmm9r); + zmm_t zmm_t8 = vtype::min(zmm[7], zmm8r); + zmm_t zmm_t9 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm8r)); + zmm_t zmm_t10 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm9r)); + zmm_t zmm_t11 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm10r)); + zmm_t zmm_t12 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm11r)); + zmm_t zmm_t13 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm12r)); + zmm_t zmm_t14 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm13r)); + zmm_t zmm_t15 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm14r)); + zmm_t zmm_t16 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm15r)); + // Recursive half cleaner: 16 zmm regs + COEX(zmm_t1, zmm_t5); + COEX(zmm_t2, zmm_t6); + COEX(zmm_t3, zmm_t7); + COEX(zmm_t4, zmm_t8); + COEX(zmm_t9, zmm_t13); + COEX(zmm_t10, zmm_t14); + COEX(zmm_t11, zmm_t15); + COEX(zmm_t12, zmm_t16); + // + COEX(zmm_t1, zmm_t3); + COEX(zmm_t2, zmm_t4); + COEX(zmm_t5, zmm_t7); + COEX(zmm_t6, zmm_t8); + COEX(zmm_t9, zmm_t11); + COEX(zmm_t10, zmm_t12); + COEX(zmm_t13, zmm_t15); + COEX(zmm_t14, zmm_t16); + // + COEX(zmm_t1, zmm_t2); + COEX(zmm_t3, zmm_t4); + COEX(zmm_t5, zmm_t6); + COEX(zmm_t7, zmm_t8); + COEX(zmm_t9, zmm_t10); + COEX(zmm_t11, zmm_t12); + COEX(zmm_t13, zmm_t14); + COEX(zmm_t15, zmm_t16); + // + zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); + zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); + zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); + zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); + zmm[4] = bitonic_merge_zmm_64bit(zmm_t5);
+ zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); + zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); + zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); + zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); + zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); + zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); + zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); + zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); + zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); + zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); + zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); +} + +template +X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]); + zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]); + zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]); + zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]); + zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]); + zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]); + zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]); + zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]); + zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]); + zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]); + zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]); + zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]); + zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]); + zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]); + zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]); + zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r); + zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r); + zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r); + zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r); + zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r); + zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r); + zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r); + zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r); + zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r); + zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r); + zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r); + zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r); + zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r); + zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r)); + zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r)); + zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r)); + zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r)); + zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r)); + zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r)); + zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r)); + zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r)); + zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r)); + zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r)); + zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r)); + zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r)); + zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r)); + zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r)); + zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r)); + zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r)); + // Recursive half cleaner: 16 
zmm regs + COEX(zmm_t1, zmm_t9); + COEX(zmm_t2, zmm_t10); + COEX(zmm_t3, zmm_t11); + COEX(zmm_t4, zmm_t12); + COEX(zmm_t5, zmm_t13); + COEX(zmm_t6, zmm_t14); + COEX(zmm_t7, zmm_t15); + COEX(zmm_t8, zmm_t16); + COEX(zmm_t17, zmm_t25); + COEX(zmm_t18, zmm_t26); + COEX(zmm_t19, zmm_t27); + COEX(zmm_t20, zmm_t28); + COEX(zmm_t21, zmm_t29); + COEX(zmm_t22, zmm_t30); + COEX(zmm_t23, zmm_t31); + COEX(zmm_t24, zmm_t32); + // + COEX(zmm_t1, zmm_t5); + COEX(zmm_t2, zmm_t6); + COEX(zmm_t3, zmm_t7); + COEX(zmm_t4, zmm_t8); + COEX(zmm_t9, zmm_t13); + COEX(zmm_t10, zmm_t14); + COEX(zmm_t11, zmm_t15); + COEX(zmm_t12, zmm_t16); + COEX(zmm_t17, zmm_t21); + COEX(zmm_t18, zmm_t22); + COEX(zmm_t19, zmm_t23); + COEX(zmm_t20, zmm_t24); + COEX(zmm_t25, zmm_t29); + COEX(zmm_t26, zmm_t30); + COEX(zmm_t27, zmm_t31); + COEX(zmm_t28, zmm_t32); + // + COEX(zmm_t1, zmm_t3); + COEX(zmm_t2, zmm_t4); + COEX(zmm_t5, zmm_t7); + COEX(zmm_t6, zmm_t8); + COEX(zmm_t9, zmm_t11); + COEX(zmm_t10, zmm_t12); + COEX(zmm_t13, zmm_t15); + COEX(zmm_t14, zmm_t16); + COEX(zmm_t17, zmm_t19); + COEX(zmm_t18, zmm_t20); + COEX(zmm_t21, zmm_t23); + COEX(zmm_t22, zmm_t24); + COEX(zmm_t25, zmm_t27); + COEX(zmm_t26, zmm_t28); + COEX(zmm_t29, zmm_t31); + COEX(zmm_t30, zmm_t32); + // + COEX(zmm_t1, zmm_t2); + COEX(zmm_t3, zmm_t4); + COEX(zmm_t5, zmm_t6); + COEX(zmm_t7, zmm_t8); + COEX(zmm_t9, zmm_t10); + COEX(zmm_t11, zmm_t12); + COEX(zmm_t13, zmm_t14); + COEX(zmm_t15, zmm_t16); + COEX(zmm_t17, zmm_t18); + COEX(zmm_t19, zmm_t20); + COEX(zmm_t21, zmm_t22); + COEX(zmm_t23, zmm_t24); + COEX(zmm_t25, zmm_t26); + COEX(zmm_t27, zmm_t28); + COEX(zmm_t29, zmm_t30); + COEX(zmm_t31, zmm_t32); + // + zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); + zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); + zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); + zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); + zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); + zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); + zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); + zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); + zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); + zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); + zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); + zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); + zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); + zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); + zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); + zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); + zmm[16] = bitonic_merge_zmm_64bit(zmm_t17); + zmm[17] = bitonic_merge_zmm_64bit(zmm_t18); + zmm[18] = bitonic_merge_zmm_64bit(zmm_t19); + zmm[19] = bitonic_merge_zmm_64bit(zmm_t20); + zmm[20] = bitonic_merge_zmm_64bit(zmm_t21); + zmm[21] = bitonic_merge_zmm_64bit(zmm_t22); + zmm[22] = bitonic_merge_zmm_64bit(zmm_t23); + zmm[23] = bitonic_merge_zmm_64bit(zmm_t24); + zmm[24] = bitonic_merge_zmm_64bit(zmm_t25); + zmm[25] = bitonic_merge_zmm_64bit(zmm_t26); + zmm[26] = bitonic_merge_zmm_64bit(zmm_t27); + zmm[27] = bitonic_merge_zmm_64bit(zmm_t28); + zmm[28] = bitonic_merge_zmm_64bit(zmm_t29); + zmm[29] = bitonic_merge_zmm_64bit(zmm_t30); + zmm[30] = bitonic_merge_zmm_64bit(zmm_t31); + zmm[31] = bitonic_merge_zmm_64bit(zmm_t32); +} + +template +X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) { + typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype::zmm_t zmm = + vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); + vtype::mask_storeu(arr, load_mask, sort_zmm_64bit(zmm)); +} + +template +X86_SIMD_SORT_INLINE void sort_16_64bit(type_t *arr, int32_t N) { + if (N <= 8) { + sort_8_64bit(arr, N); + return; + } + using zmm_t = typename 
vtype::zmm_t; + zmm_t zmm1 = vtype::loadu(arr); + typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8); + zmm1 = sort_zmm_64bit(zmm1); + zmm2 = sort_zmm_64bit(zmm2); + bitonic_merge_two_zmm_64bit(zmm1, zmm2); + vtype::storeu(arr, zmm1); + vtype::mask_storeu(arr + 8, load_mask, zmm2); +} + +template +X86_SIMD_SORT_INLINE void sort_32_64bit(type_t *arr, int32_t N) { + if (N <= 16) { + sort_16_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[4]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16); + zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_four_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::mask_storeu(arr + 16, load_mask1, zmm[2]); + vtype::mask_storeu(arr + 24, load_mask2, zmm[3]); +} + +template +X86_SIMD_SORT_INLINE void sort_64_64bit(type_t *arr, int32_t N) { + if (N <= 32) { + sort_32_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[8]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + // N-32 >= 1 + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); + zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40); + zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48); + zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_eight_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::mask_storeu(arr + 32, load_mask1, zmm[4]); + vtype::mask_storeu(arr + 40, load_mask2, zmm[5]); + vtype::mask_storeu(arr + 48, load_mask3, zmm[6]); + vtype::mask_storeu(arr + 56, load_mask4, zmm[7]); +} + +template +X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) { + if (N <= 64) { + sort_64_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = 
typename vtype::opmask_t; + zmm_t zmm[16]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[4] = vtype::loadu(arr + 32); + zmm[5] = vtype::loadu(arr + 40); + zmm[6] = vtype::loadu(arr + 48); + zmm[7] = vtype::loadu(arr + 56); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + if (N != 128) { + uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + } + zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); + zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 72); + zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80); + zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88); + zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96); + zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104); + zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112); + zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 120); + zmm[8] = sort_zmm_64bit(zmm[8]); + zmm[9] = sort_zmm_64bit(zmm[9]); + zmm[10] = sort_zmm_64bit(zmm[10]); + zmm[11] = sort_zmm_64bit(zmm[11]); + zmm[12] = sort_zmm_64bit(zmm[12]); + zmm[13] = sort_zmm_64bit(zmm[13]); + zmm[14] = sort_zmm_64bit(zmm[14]); + zmm[15] = sort_zmm_64bit(zmm[15]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); + bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); + bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); + bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_four_zmm_64bit(zmm + 8); + bitonic_merge_four_zmm_64bit(zmm + 12); + bitonic_merge_eight_zmm_64bit(zmm); + bitonic_merge_eight_zmm_64bit(zmm + 8); + bitonic_merge_sixteen_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::storeu(arr + 32, zmm[4]); + vtype::storeu(arr + 40, zmm[5]); + vtype::storeu(arr + 48, zmm[6]); + vtype::storeu(arr + 56, zmm[7]); + vtype::mask_storeu(arr + 64, load_mask1, zmm[8]); + vtype::mask_storeu(arr + 72, load_mask2, zmm[9]); + vtype::mask_storeu(arr + 80, load_mask3, zmm[10]); + vtype::mask_storeu(arr + 88, load_mask4, zmm[11]); + vtype::mask_storeu(arr + 96, load_mask5, zmm[12]); + vtype::mask_storeu(arr + 104, load_mask6, zmm[13]); + vtype::mask_storeu(arr + 112, load_mask7, zmm[14]); + vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); +} + +template +X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) { + if (N <= 128) { + 
sort_128_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[32]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[4] = vtype::loadu(arr + 32); + zmm[5] = vtype::loadu(arr + 40); + zmm[6] = vtype::loadu(arr + 48); + zmm[7] = vtype::loadu(arr + 56); + zmm[8] = vtype::loadu(arr + 64); + zmm[9] = vtype::loadu(arr + 72); + zmm[10] = vtype::loadu(arr + 80); + zmm[11] = vtype::loadu(arr + 88); + zmm[12] = vtype::loadu(arr + 96); + zmm[13] = vtype::loadu(arr + 104); + zmm[14] = vtype::loadu(arr + 112); + zmm[15] = vtype::loadu(arr + 120); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + zmm[8] = sort_zmm_64bit(zmm[8]); + zmm[9] = sort_zmm_64bit(zmm[9]); + zmm[10] = sort_zmm_64bit(zmm[10]); + zmm[11] = sort_zmm_64bit(zmm[11]); + zmm[12] = sort_zmm_64bit(zmm[12]); + zmm[13] = sort_zmm_64bit(zmm[13]); + zmm[14] = sort_zmm_64bit(zmm[14]); + zmm[15] = sort_zmm_64bit(zmm[15]); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF; + opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF; + opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF; + opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF; + if (N != 256) { + uint64_t combined_mask; + if (N < 192) { + combined_mask = (0x1ull << (N - 128)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + load_mask9 = 0x00; + load_mask10 = 0x0; + load_mask11 = 0x00; + load_mask12 = 0x00; + load_mask13 = 0x00; + load_mask14 = 0x00; + load_mask15 = 0x00; + load_mask16 = 0x00; + } else { + combined_mask = (0x1ull << (N - 192)) - 0x1ull; + load_mask9 = (combined_mask)&0xFF; + load_mask10 = (combined_mask >> 8) & 0xFF; + load_mask11 = (combined_mask >> 16) & 0xFF; + load_mask12 = (combined_mask >> 24) & 0xFF; + load_mask13 = (combined_mask >> 32) & 0xFF; + load_mask14 = (combined_mask >> 40) & 0xFF; + load_mask15 = (combined_mask >> 48) & 0xFF; + load_mask16 = (combined_mask >> 56) & 0xFF; + } + } + zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128); + zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136); + zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144); + zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152); + zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160); + zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168); + zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176); + zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184); + if (N < 192) { + zmm[24] = vtype::zmm_max(); + zmm[25] = vtype::zmm_max(); + zmm[26] = vtype::zmm_max(); + zmm[27] = vtype::zmm_max(); + zmm[28] = vtype::zmm_max(); + zmm[29] = vtype::zmm_max(); + zmm[30] = vtype::zmm_max(); + zmm[31] = 
vtype::zmm_max(); + } else { + zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192); + zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200); + zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208); + zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216); + zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224); + zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232); + zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240); + zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248); + } + zmm[16] = sort_zmm_64bit(zmm[16]); + zmm[17] = sort_zmm_64bit(zmm[17]); + zmm[18] = sort_zmm_64bit(zmm[18]); + zmm[19] = sort_zmm_64bit(zmm[19]); + zmm[20] = sort_zmm_64bit(zmm[20]); + zmm[21] = sort_zmm_64bit(zmm[21]); + zmm[22] = sort_zmm_64bit(zmm[22]); + zmm[23] = sort_zmm_64bit(zmm[23]); + zmm[24] = sort_zmm_64bit(zmm[24]); + zmm[25] = sort_zmm_64bit(zmm[25]); + zmm[26] = sort_zmm_64bit(zmm[26]); + zmm[27] = sort_zmm_64bit(zmm[27]); + zmm[28] = sort_zmm_64bit(zmm[28]); + zmm[29] = sort_zmm_64bit(zmm[29]); + zmm[30] = sort_zmm_64bit(zmm[30]); + zmm[31] = sort_zmm_64bit(zmm[31]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); + bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); + bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); + bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); + bitonic_merge_two_zmm_64bit(zmm[16], zmm[17]); + bitonic_merge_two_zmm_64bit(zmm[18], zmm[19]); + bitonic_merge_two_zmm_64bit(zmm[20], zmm[21]); + bitonic_merge_two_zmm_64bit(zmm[22], zmm[23]); + bitonic_merge_two_zmm_64bit(zmm[24], zmm[25]); + bitonic_merge_two_zmm_64bit(zmm[26], zmm[27]); + bitonic_merge_two_zmm_64bit(zmm[28], zmm[29]); + bitonic_merge_two_zmm_64bit(zmm[30], zmm[31]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_four_zmm_64bit(zmm + 8); + bitonic_merge_four_zmm_64bit(zmm + 12); + bitonic_merge_four_zmm_64bit(zmm + 16); + bitonic_merge_four_zmm_64bit(zmm + 20); + bitonic_merge_four_zmm_64bit(zmm + 24); + bitonic_merge_four_zmm_64bit(zmm + 28); + bitonic_merge_eight_zmm_64bit(zmm); + bitonic_merge_eight_zmm_64bit(zmm + 8); + bitonic_merge_eight_zmm_64bit(zmm + 16); + bitonic_merge_eight_zmm_64bit(zmm + 24); + bitonic_merge_sixteen_zmm_64bit(zmm); + bitonic_merge_sixteen_zmm_64bit(zmm + 16); + bitonic_merge_32_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::storeu(arr + 32, zmm[4]); + vtype::storeu(arr + 40, zmm[5]); + vtype::storeu(arr + 48, zmm[6]); + vtype::storeu(arr + 56, zmm[7]); + vtype::storeu(arr + 64, zmm[8]); + vtype::storeu(arr + 72, zmm[9]); + vtype::storeu(arr + 80, zmm[10]); + vtype::storeu(arr + 88, zmm[11]); + vtype::storeu(arr + 96, zmm[12]); + vtype::storeu(arr + 104, zmm[13]); + vtype::storeu(arr + 112, zmm[14]); + vtype::storeu(arr + 120, zmm[15]); + vtype::mask_storeu(arr + 128, load_mask1, zmm[16]); + vtype::mask_storeu(arr + 136, load_mask2, zmm[17]); + vtype::mask_storeu(arr + 144, load_mask3, zmm[18]); + vtype::mask_storeu(arr + 152, load_mask4, zmm[19]); + vtype::mask_storeu(arr + 160, load_mask5, zmm[20]); + vtype::mask_storeu(arr + 168, load_mask6, zmm[21]); + vtype::mask_storeu(arr + 176, load_mask7, zmm[22]); + 
vtype::mask_storeu(arr + 184, load_mask8, zmm[23]);
+    if (N > 192) {
+        vtype::mask_storeu(arr + 192, load_mask9, zmm[24]);
+        vtype::mask_storeu(arr + 200, load_mask10, zmm[25]);
+        vtype::mask_storeu(arr + 208, load_mask11, zmm[26]);
+        vtype::mask_storeu(arr + 216, load_mask12, zmm[27]);
+        vtype::mask_storeu(arr + 224, load_mask13, zmm[28]);
+        vtype::mask_storeu(arr + 232, load_mask14, zmm[29]);
+        vtype::mask_storeu(arr + 240, load_mask15, zmm[30]);
+        vtype::mask_storeu(arr + 248, load_mask16, zmm[31]);
+    }
+}
+
+template <typename vtype, typename type_t>
+static void qsort_64bit_(type_t *arr, int64_t left, int64_t right,
+                         int64_t max_iters) {
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std::sort(arr + left, arr + right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 256
+     */
+    if (right + 1 - left <= 256) {
+        sort_256_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        return;
+    }
+
+    type_t pivot = get_pivot_scalar<type_t>(arr, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    int64_t pivot_index = partition_avx512_unrolled<vtype, 8>(
+        arr, left, right + 1, pivot, &smallest, &biggest, false);
+    if (pivot != smallest)
+        qsort_64bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
+    if (pivot != biggest)
+        qsort_64bit_<vtype>(arr, pivot_index, right, max_iters - 1);
+}
+
+template <>
+void inline avx512_qsort<int64_t>(int64_t *arr, int64_t fromIndex, int64_t toIndex) {
+    int64_t arrsize = toIndex - fromIndex;
+    if (arrsize > 1) {
+        qsort_64bit_<zmm_vector<int64_t>, int64_t>(arr, fromIndex, toIndex - 1,
+                                                   2 * (int64_t)log2(arrsize));
+    }
+}
+
+template <>
+void inline avx512_qsort<double>(double *arr, int64_t fromIndex, int64_t toIndex) {
+    int64_t arrsize = toIndex - fromIndex;
+    if (arrsize > 1) {
+        qsort_64bit_<zmm_vector<double>, double>(arr, fromIndex, toIndex - 1,
+                                                 2 * (int64_t)log2(arrsize));
+    }
+}
+
+#endif // AVX512_QSORT_64BIT
diff --git a/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h b/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h
new file mode 100644
index 0000000000000..b008bcd54b80c
--- /dev/null
+++ b/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h
@@ -0,0 +1,474 @@
+/*
+ * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
+ * Copyright (c) 2021 Serge Sans Paille. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ *
+ */
+
+// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)
+#ifndef AVX512_QSORT_COMMON
+#define AVX512_QSORT_COMMON
+
+/*
+ * Quicksort using AVX-512.
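The driver above reads more easily with the vector machinery stripped away. The following scalar model is illustrative only (std::partition and a midpoint pivot stand in for the AVX-512 partitioner and the median-of-8 sampling): the 2 * log2(n) depth budget degrades to std::sort exactly as qsort_64bit_ does, small ranges go to a base-case sorter in place of the bitonic networks, and the smallest/biggest guards skip any side known to consist of one repeated value.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar stand-in for partition_avx512_unrolled: everything < pivot first,
// returns the index of the first element >= pivot.
static int64_t scalar_partition(int64_t *arr, int64_t left, int64_t right_ex,
                                int64_t pivot) {
    int64_t *mid = std::partition(arr + left, arr + right_ex,
                                  [pivot](int64_t v) { return v < pivot; });
    return mid - arr;
}

static void qsort_model(int64_t *arr, int64_t left, int64_t right,
                        int64_t max_iters) {
    if (max_iters <= 0) {              // pivots stopped making progress
        std::sort(arr + left, arr + right + 1);
        return;
    }
    if (right + 1 - left <= 256) {     // bitonic-network base case in the real code
        std::sort(arr + left, arr + right + 1);
        return;
    }
    auto mm = std::minmax_element(arr + left, arr + right + 1);
    int64_t smallest = *mm.first, biggest = *mm.second;
    int64_t pivot = arr[left + (right - left) / 2];  // real code: median of 8 samples
    int64_t pivot_index = scalar_partition(arr, left, right + 1, pivot);
    if (pivot != smallest)
        qsort_model(arr, left, pivot_index - 1, max_iters - 1);
    if (pivot != biggest)
        qsort_model(arr, pivot_index, right, max_iters - 1);
}

int main() {
    std::vector<int64_t> v(100000);
    for (size_t i = 0; i < v.size(); i++) v[i] = (int64_t)((i * 2654435761ULL) % 9973);
    qsort_model(v.data(), 0, (int64_t)v.size() - 1,
                2 * (int64_t)std::log2((double)v.size()));
    assert(std::is_sorted(v.begin(), v.end()));
}
```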
The ideas and code are based on these two research + * papers [1] and [2]. On a high level, the idea is to vectorize quicksort + * partitioning using AVX-512 compressstore instructions. If the array size is + * < 128, then use Bitonic sorting network implemented on 512-bit registers. + * The precise network definitions depend on the dtype and are defined in + * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and + * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting + * network. The core implementations of the vectorized qsort functions + * avx512_qsort(T*, int64_t) are modified versions of avx2 quicksort + * presented in the paper [2] and source code associated with that paper [3]. + * + * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types + * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ + * + * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel + * Skylake https://arxiv.org/pdf/1704.08579.pdf + * + * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: + * MIT + * + * [4] + * http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 + * + */ + +#include +#include +#include +#include +#include +#include + +#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYH 0x7c00 +#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 +#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() +#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) +#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) +#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) +#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) +#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) +#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) +#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) +#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) +#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d + +#ifdef _MSC_VER +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __forceinline +#elif defined(__CYGWIN__) +/* + * Force inline in cygwin to work around a compiler bug. 
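The header comment above leans on two building blocks that are easy to show in scalar form: COEX, the compare-and-exchange step defined later in this file, and data-independent sorting networks built from it. The sketch below is illustrative (a plain 4-input network on scalars); the real code runs the same fixed comparator sequence on whole ZMM registers using masked min/max, which is what makes it branch-free and vectorizable.

```cpp
#include <algorithm>
#include <cassert>

// Scalar compare-exchange, the COEX primitive the networks are built from:
// after the call, a holds the min and b the max.
template <typename T>
static void coex(T &a, T &b) {
    T t = a;
    a = std::min(a, b);
    b = std::max(t, b);
}

// A 4-input sorting network: the same five COEX steps sort every input,
// with no data-dependent branches.
template <typename T>
static void sort4(T v[4]) {
    coex(v[0], v[1]); coex(v[2], v[3]);
    coex(v[0], v[2]); coex(v[1], v[3]);
    coex(v[1], v[2]);
}

int main() {
    int v[4] = {3, 1, 4, 1};
    sort4(v);
    assert(std::is_sorted(v, v + 4));
}
```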
See + * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 + */ +#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#elif defined(__GNUC__) +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#else +#define X86_SIMD_SORT_INLINE static +#define X86_SIMD_SORT_FINLINE static +#endif + +#define LIKELY(x) __builtin_expect((x), 1) +#define UNLIKELY(x) __builtin_expect((x), 0) + +template +struct zmm_vector; + +template +struct ymm_vector; + +// Regular quicksort routines: +template +void avx512_qsort(T *arr, int64_t arrsize); + +template +void inline avx512_qsort(T *arr, int64_t from_index, int64_t to_index); + +template +bool is_a_nan(T elem) { + return std::isnan(elem); +} + +template +X86_SIMD_SORT_INLINE T get_pivot_scalar(T *arr, const int64_t left, const int64_t right) { + // median of 8 equally spaced elements + int64_t NUM_ELEMENTS = 8; + int64_t MID = NUM_ELEMENTS / 2; + int64_t size = (right - left) / NUM_ELEMENTS; + T temp[NUM_ELEMENTS]; + for (int64_t i = 0; i < NUM_ELEMENTS; i++) temp[i] = arr[left + (i * size)]; + std::sort(temp, temp + NUM_ELEMENTS); + return temp[MID]; +} + +template +bool comparison_func_ge(const T &a, const T &b) { + return a < b; +} + +template +bool comparison_func_gt(const T &a, const T &b) { + return a <= b; +} + +/* + * COEX == Compare and Exchange two registers by swapping min and max values + */ +template +static void COEX(mm_t &a, mm_t &b) { + mm_t temp = a; + a = vtype::min(a, b); + b = vtype::max(temp, b); +} +template +static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) { + zmm_t min = vtype::min(in2, in1); + zmm_t max = vtype::max(in2, in1); + return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max +} +/* + * Parition one ZMM register based on the pivot and returns the + * number of elements that are greater than or equal to the pivot. + */ +template +static inline int32_t partition_vec(type_t *arr, int64_t left, int64_t right, + const zmm_t curr_vec, const zmm_t pivot_vec, + zmm_t *smallest_vec, zmm_t *biggest_vec, bool use_gt) { + /* which elements are larger than or equal to the pivot */ + typename vtype::opmask_t mask; + if (use_gt) mask = vtype::gt(curr_vec, pivot_vec); + else mask = vtype::ge(curr_vec, pivot_vec); + //mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_ge_pivot = _mm_popcnt_u32((int32_t)mask); + vtype::mask_compressstoreu(arr + left, vtype::knot_opmask(mask), + curr_vec); + vtype::mask_compressstoreu(arr + right - amount_ge_pivot, mask, + curr_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_ge_pivot; +} +/* + * Parition an array based on the pivot and returns the index of the + * first element that is greater than or equal to the pivot. + */ +template +static inline int64_t partition_avx512(type_t *arr, int64_t left, int64_t right, + type_t pivot, type_t *smallest, + type_t *biggest, bool use_gt) { + auto comparison_func = use_gt ? 
comparison_func_gt : comparison_func_ge; + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + zmm_t vec = vtype::loadu(arr + left); + int32_t amount_ge_pivot = + partition_vec(arr, left, left + vtype::numlanes, vec, + pivot_vec, &min_vec, &max_vec, use_gt); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_ge_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + zmm_t vec_left = vtype::loadu(arr + left); + zmm_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + zmm_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + curr_vec = vtype::loadu(arr + right); + } else { + curr_vec = vtype::loadu(arr + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_ge_pivot = + partition_vec(arr, l_store, r_store + vtype::numlanes, + curr_vec, pivot_vec, &min_vec, &max_vec, use_gt); + ; + r_store -= amount_ge_pivot; + l_store += (vtype::numlanes - amount_ge_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_ge_pivot = + partition_vec(arr, l_store, r_store + vtype::numlanes, vec_left, + pivot_vec, &min_vec, &max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + amount_ge_pivot = + partition_vec(arr, l_store, l_store + vtype::numlanes, vec_right, + pivot_vec, &min_vec, &max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +static inline int64_t partition_avx512_unrolled(type_t *arr, int64_t left, + int64_t right, type_t pivot, + type_t *smallest, + type_t *biggest, bool use_gt) { + if (right - left <= 2 * num_unroll * vtype::numlanes) { + return partition_avx512(arr, left, right, pivot, smallest, + biggest, use_gt); + } + + auto comparison_func = use_gt ? 
comparison_func_gt : comparison_func_ge; + /* make array length divisible by 8*vtype::numlanes , shortening the array + */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + // We will now have atleast 16 registers worth of data to process: + // left and right vtype::numlanes values are partitioned at the end + zmm_t vec_left[num_unroll], vec_right[num_unroll]; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); + vec_right[ii] = + vtype::loadu(arr + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + zmm_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); + } + } else { +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } +// partition the current vector and save it on both sides of the array +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot = partition_vec( + arr, l_store, r_store + vtype::numlanes, curr_vec[ii], + pivot_vec, &min_vec, &max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } + } + +/* partition and save vec_left[8] and vec_right[8] */ +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot = + partition_vec(arr, l_store, r_store + vtype::numlanes, + vec_left[ii], pivot_vec, &min_vec, &max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot = + partition_vec(arr, l_store, r_store + vtype::numlanes, + vec_right[ii], pivot_vec, &min_vec, &max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +// to_index (exclusive) +template +static int64_t vectorized_partition(type_t *arr, int64_t from_index, int64_t to_index, type_t pivot, bool use_gt) { + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, from_index, to_index, pivot, &smallest, &biggest, use_gt); + return pivot_index; +} + +// partitioning functions +template +void avx512_dual_pivot_partition(T 
*arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2){ + const T pivot1 = arr[index_pivot1]; + const T pivot2 = arr[index_pivot2]; + + const int64_t low = from_index; + const int64_t high = to_index; + const int64_t start = low + 1; + const int64_t end = high - 1; + + + std::swap(arr[index_pivot1], arr[low]); + std::swap(arr[index_pivot2], arr[end]); + + + const int64_t pivot_index2 = vectorized_partition, T>(arr, start, end, pivot2, true); // use_gt = true + std::swap(arr[end], arr[pivot_index2]); + int64_t upper = pivot_index2; + + // if all other elements are greater than pivot2 (and pivot1), no need to do further partitioning + if (upper == start) { + pivot_indices[0] = low; + pivot_indices[1] = upper; + return; + } + + const int64_t pivot_index1 = vectorized_partition, T>(arr, start, upper, pivot1, false); // use_ge (use_gt = false) + int64_t lower = pivot_index1 - 1; + std::swap(arr[low], arr[lower]); + + pivot_indices[0] = lower; + pivot_indices[1] = upper; +} + +template +void avx512_single_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot){ + const T pivot = arr[index_pivot]; + + const int64_t low = from_index; + const int64_t high = to_index; + const int64_t end = high - 1; + + + const int64_t pivot_index1 = vectorized_partition, T>(arr, low, high, pivot, false); // use_gt = false (use_ge) + int64_t lower = pivot_index1; + + const int64_t pivot_index2 = vectorized_partition, T>(arr, pivot_index1, high, pivot, true); // use_gt = true + int64_t upper = pivot_index2; + + pivot_indices[0] = lower; + pivot_indices[1] = upper; +} + +template +void inline avx512_fast_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2) { + if (index_pivot1 != index_pivot2) { + avx512_dual_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + } + else { + avx512_single_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1); + } +} + +template +void inline insertion_sort(T *arr, int32_t from_index, int32_t to_index) { + for (int i, k = from_index; ++k < to_index; ) { + T ai = arr[i = k]; + + if (ai < arr[i - 1]) { + while (--i >= from_index && ai < arr[i]) { + arr[i + 1] = arr[i]; + } + arr[i + 1] = ai; + } + } +} + +template +void inline avx512_fast_sort(T *arr, int64_t from_index, int64_t to_index, const int32_t INS_SORT_THRESHOLD) { + int32_t size = to_index - from_index; + + if (size <= INS_SORT_THRESHOLD) { + insertion_sort(arr, from_index, to_index); + } + else { + avx512_qsort(arr, from_index, to_index); + } +} + + + +#endif // AVX512_QSORT_COMMON diff --git a/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp b/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp new file mode 100644 index 0000000000000..6bd0c5871d6cb --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Intel Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +#pragma GCC target("avx512dq", "avx512f") +#include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-qsort.hpp" +#include "classfile_constants.h" + +#define DLL_PUBLIC __attribute__((visibility("default"))) +#define INSERTION_SORT_THRESHOLD_32BIT 16 +#define INSERTION_SORT_THRESHOLD_64BIT 20 + +extern "C" { + + DLL_PUBLIC void avx512_sort(void *array, int elem_type, int32_t from_index, int32_t to_index) { + switch(elem_type) { + case JVM_T_INT: + avx512_fast_sort((int32_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + break; + case JVM_T_LONG: + avx512_fast_sort((int64_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); + break; + case JVM_T_FLOAT: + avx512_fast_sort((float*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + break; + case JVM_T_DOUBLE: + avx512_fast_sort((double*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); + break; + } + } + + DLL_PUBLIC void avx512_partition(void *array, int elem_type, int32_t from_index, int32_t to_index, int32_t *pivot_indices, int32_t index_pivot1, int32_t index_pivot2) { + switch(elem_type) { + case JVM_T_INT: + avx512_fast_partition((int32_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + case JVM_T_LONG: + avx512_fast_partition((int64_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + case JVM_T_FLOAT: + avx512_fast_partition((float*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + case JVM_T_DOUBLE: + avx512_fast_partition((double*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + } + } + +} diff --git a/src/java.base/share/classes/java/util/DualPivotQuicksort.java b/src/java.base/share/classes/java/util/DualPivotQuicksort.java index 3dcc7fee1f525..0dd4b6e354aed 100644 --- a/src/java.base/share/classes/java/util/DualPivotQuicksort.java +++ b/src/java.base/share/classes/java/util/DualPivotQuicksort.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2009, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2009, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,6 +27,9 @@ import java.util.concurrent.CountedCompleter; import java.util.concurrent.RecursiveTask; +import jdk.internal.misc.Unsafe; +import jdk.internal.vm.annotation.IntrinsicCandidate; +import jdk.internal.vm.annotation.ForceInline; /** * This class implements powerful and fully optimized versions, both @@ -120,6 +123,78 @@ private DualPivotQuicksort() {} */ private static final int MAX_RECURSION_DEPTH = 64 * DELTA; + /** + * Represents a function that accepts the array and sorts the specified range + * of the array into ascending order. + */ + @FunctionalInterface + private static interface SortOperation { + /** + * Sorts the specified range of the array. 
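The two extern "C" symbols above are what the VM looks up at stub-generation time. A hypothetical harness can smoke-test the shared object outside the JVM; the library path, and running on an AVX-512DQ machine, are assumptions here. The constant 10 is JVM_T_INT from classfile_constants.h.

```cpp
#include <dlfcn.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Signature of the avx512_sort entry point exported above.
using avx512_sort_fn = void (*)(void *array, int elem_type,
                                int32_t from_index, int32_t to_index);

int main() {
    // Path is an assumption; the JVM locates the library via its own loader.
    void *lib = dlopen("./libsimdsort.so", RTLD_NOW);
    if (lib == nullptr) { fprintf(stderr, "dlopen: %s\n", dlerror()); return 1; }

    auto sort = (avx512_sort_fn)dlsym(lib, "avx512_sort");
    assert(sort != nullptr);

    int32_t data[] = {5, 3, 9, 1, 7, 1, 0, 8};
    sort(data, 10 /* JVM_T_INT */, 0, 8);   // sorts indices [0, 8)
    assert(std::is_sorted(data, data + 8));
    puts("ok");
    dlclose(lib);
    return 0;
}
```

On older glibc this needs -ldl at link time, and it will only exercise the AVX-512 paths on hardware that supports them.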
+ * + * @param a the array to be sorted + * @param low the index of the first element, inclusive, to be sorted + * @param high the index of the last element, exclusive, to be sorted + */ + void sort(A a, int low, int high); + } + + /** + * Sorts the specified range of the array into ascending numerical order. + * + * @param elemType the class of the elements of the array to be sorted + * @param array the array to be sorted + * @param offset the relative offset, in bytes, from the base address of + * the array to sort, otherwise if the array is {@code null},an absolute + * address pointing to the first element to sort from. + * @param low the index of the first element, inclusive, to be sorted + * @param high the index of the last element, exclusive, to be sorted + * @param so the method reference for the fallback implementation + */ + @IntrinsicCandidate + @ForceInline + private static void sort(Class elemType, A array, long offset, int low, int high, SortOperation so) { + so.sort(array, low, high); + } + + /** + * Represents a function that accepts the array and partitions the specified range + * of the array using the pivots provided. + */ + @FunctionalInterface + interface PartitionOperation { + /** + * Partitions the specified range of the array using the given pivots. + * + * @param a the array to be sorted + * @param low the index of the first element, inclusive, to be sorted + * @param high the index of the last element, exclusive, to be sorted + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + */ + int[] partition(A a, int low, int high, int pivotIndex1, int pivotIndex2); + } + + /** + * Partitions the specified range of the array using the two pivots provided. + * + * @param elemType the class of the array to be sorted + * @param array the array to be sorted + * @param offset the relative offset, in bytes, from the base address of + * the array to partition, otherwise if the array is {@code null},an absolute + * address pointing to the first element to partition from. + * @param low the index of the first element, inclusive, to be sorted + * @param high the index of the last element, exclusive, to be sorted + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * @param po the method reference for the fallback implementation + */ + @IntrinsicCandidate + @ForceInline + private static int[] partition(Class elemType, A array, long offset, int low, int high, int pivotIndex1, int pivotIndex2, PartitionOperation po) { + return po.partition(array, low, high, pivotIndex1, pivotIndex2); + } + /** * Calculates the double depth of parallel merging. * Depth is negative, if tasks split before sorting. @@ -178,12 +253,11 @@ static void sort(int[] a, int parallelism, int low, int high) { static void sort(Sorter sorter, int[] a, int bits, int low, int high) { while (true) { int end = high - 1, size = high - low; - /* * Run mixed insertion sort on small non-leftmost parts. */ if (size < MAX_MIXED_INSERTION_SORT_SIZE + bits && (bits & 1) > 0) { - mixedInsertionSort(a, low, high - 3 * ((size >> 5) << 3), high); + sort(int.class, a, Unsafe.ARRAY_INT_BASE_OFFSET, low, high, DualPivotQuicksort::mixedInsertionSort); return; } @@ -191,7 +265,7 @@ static void sort(Sorter sorter, int[] a, int bits, int low, int high) { * Invoke insertion sort on small leftmost part. 
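Interpreted, the sort(...) wrapper above does nothing but invoke the method reference it is handed; its value is the @IntrinsicCandidate shape, which gives the JIT one call site it may replace with the native stub while the fallback stays always-correct. A C++ analogue of that shape, using a function pointer where Java uses SortOperation (illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

using sort_op = void (*)(int32_t *a, int low, int high);

static void insertion_sort_fallback(int32_t *a, int low, int high) {
    for (int k = low + 1; k < high; ++k) {
        int32_t ai = a[k];
        int i = k - 1;
        while (i >= low && a[i] > ai) { a[i + 1] = a[i]; --i; }
        a[i + 1] = ai;
    }
}

// Stand-in for the intrinsic entry point: always correct via the fallback,
// replaceable by an optimized stub when the runtime has one.
static void sort(int32_t *a, int low, int high, sort_op fallback) {
    fallback(a, low, high);   // a JIT would swap in the native stub here
}

int main() {
    int32_t v[] = {4, 2, 7, 1};
    sort(v, 0, 4, insertion_sort_fallback);
    for (int32_t x : v) printf("%d ", (int)x);   // 1 2 4 7
}
```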
*/ if (size < MAX_INSERTION_SORT_SIZE) { - insertionSort(a, low, high); + sort(int.class, a, Unsafe.ARRAY_INT_BASE_OFFSET, low, high, DualPivotQuicksort::insertionSort); return; } @@ -265,84 +339,23 @@ && tryMergeRuns(sorter, a, low, size)) { } // Pointers - int lower = low; // The index of the last element of the left part - int upper = end; // The index of the first element of the right part + int lower; // The index of the last element of the left part + int upper; // The index of the first element of the right part /* * Partitioning with 2 pivots in case of different elements. */ if (a[e1] < a[e2] && a[e2] < a[e3] && a[e3] < a[e4] && a[e4] < a[e5]) { - /* * Use the first and fifth of the five sorted elements as * the pivots. These values are inexpensive approximation * of tertiles. Note, that pivot1 < pivot2. */ - int pivot1 = a[e1]; - int pivot2 = a[e5]; - - /* - * The first and the last elements to be sorted are moved - * to the locations formerly occupied by the pivots. When - * partitioning is completed, the pivots are swapped back - * into their final positions, and excluded from the next - * subsequent sorting. - */ - a[e1] = a[lower]; - a[e5] = a[upper]; - - /* - * Skip elements, which are less or greater than the pivots. - */ - while (a[++lower] < pivot1); - while (a[--upper] > pivot2); - - /* - * Backward 3-interval partitioning - * - * left part central part right part - * +------------------------------------------------------------+ - * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | - * +------------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot1 - * pivot1 <= all in (k, upper) <= pivot2 - * all in [upper, end) > pivot2 - * - * Pointer k is the last index of ?-part - */ - for (int unused = --lower, k = ++upper; --k > lower; ) { - int ak = a[k]; + int[] pivotIndices = partition(int.class, a, Unsafe.ARRAY_INT_BASE_OFFSET, low, high, e1, e5, DualPivotQuicksort::partitionDualPivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; - if (ak < pivot1) { // Move a[k] to the left side - while (lower < k) { - if (a[++lower] >= pivot1) { - if (a[lower] > pivot2) { - a[k] = a[--upper]; - a[upper] = a[lower]; - } else { - a[k] = a[lower]; - } - a[lower] = ak; - break; - } - } - } else if (ak > pivot2) { // Move a[k] to the right side - a[k] = a[--upper]; - a[upper] = ak; - } - } - /* - * Swap the pivots into their final positions. - */ - a[low] = a[lower]; a[lower] = pivot1; - a[end] = a[upper]; a[upper] = pivot2; /* * Sort non-left parts recursively (possibly in parallel), @@ -362,73 +375,182 @@ && tryMergeRuns(sorter, a, low, size)) { * Use the third of the five sorted elements as the pivot. * This value is inexpensive approximation of the median. */ - int pivot = a[e3]; - + int[] pivotIndices = partition(int.class, a, Unsafe.ARRAY_INT_BASE_OFFSET, low, high, e3, e3, DualPivotQuicksort::partitionSinglePivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* - * The first element to be sorted is moved to the - * location formerly occupied by the pivot. After - * completion of partitioning the pivot is swapped - * back into its final position, and excluded from - * the next subsequent sorting. + * Sort the right part (possibly in parallel), excluding + * known pivot. All elements from the central part are + * equal and therefore already sorted. 
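The partition call above passes an element class, the array, and a base offset in bytes rather than anything layout-specific. That tuple is exactly enough for native code to form absolute element addresses: base plus byte offset plus scaled index. A sketch of the arithmetic with a hypothetical helper:

```cpp
#include <cstdint>

// base + byte offset + scaled index, as a stub working from
// (array, offset, low, high) would compute it. Illustrative only.
template <typename T>
static T *element_addr(void *array_base, int64_t byte_offset, int64_t index) {
    return reinterpret_cast<T *>(static_cast<char *>(array_base) + byte_offset
                                 + index * (int64_t)sizeof(T));
}

int main() {
    int32_t a[4] = {10, 20, 30, 40};
    // With a raw buffer the base offset is 0; for on-heap Java arrays the
    // caller passes Unsafe.ARRAY_INT_BASE_OFFSET instead.
    return *element_addr<int32_t>(a, 0, 2) == 30 ? 0 : 1;
}
```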
*/ - a[e3] = a[lower]; + if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { + sorter.forkSorter(bits | 1, upper, high); + } else { + sort(sorter, a, bits | 1, upper, high); + } + } + high = lower; // Iterate along the left part + } + } - /* - * Traditional 3-way (Dutch National Flag) partitioning - * - * left part central part right part - * +------------------------------------------------------+ - * | < pivot | ? | == pivot | > pivot | - * +------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot - * all in (k, upper) == pivot - * all in [upper, end] > pivot - * - * Pointer k is the last index of ?-part - */ - for (int k = ++upper; --k > lower; ) { - int ak = a[k]; + /** + * Partitions the specified range of the array using the two pivots provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionDualPivot(int[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + int end = high - 1; + int lower = low; + int upper = end; - if (ak != pivot) { - a[k] = pivot; + int e1 = pivotIndex1; + int e5 = pivotIndex2; + int pivot1 = a[e1]; + int pivot2 = a[e5]; - if (ak < pivot) { // Move a[k] to the left side - while (a[++lower] < pivot); + /* + * The first and the last elements to be sorted are moved + * to the locations formerly occupied by the pivots. When + * partitioning is completed, the pivots are swapped back + * into their final positions, and excluded from the next + * subsequent sorting. + */ + a[e1] = a[lower]; + a[e5] = a[upper]; - if (a[lower] > pivot) { - a[--upper] = a[lower]; - } - a[lower] = ak; - } else { // ak > pivot - Move a[k] to the right side - a[--upper] = ak; + /* + * Skip elements, which are less or greater than the pivots. + */ + while (a[++lower] < pivot1); + while (a[--upper] > pivot2); + + /* + * Backward 3-interval partitioning + * + * left part central part right part + * +------------------------------------------------------------+ + * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | + * +------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot1 + * pivot1 <= all in (k, upper) <= pivot2 + * all in [upper, end) > pivot2 + * + * Pointer k is the last index of ?-part + */ + for (int unused = --lower, k = ++upper; --k > lower; ) { + int ak = a[k]; + + if (ak < pivot1) { // Move a[k] to the left side + while (lower < k) { + if (a[++lower] >= pivot1) { + if (a[lower] > pivot2) { + a[k] = a[--upper]; + a[upper] = a[lower]; + } else { + a[k] = a[lower]; } + a[lower] = ak; + break; } } + } else if (ak > pivot2) { // Move a[k] to the right side + a[k] = a[--upper]; + a[upper] = ak; + } + } - /* - * Swap the pivot into its final position. - */ - a[low] = a[lower]; a[lower] = pivot; + /* + * Swap the pivots into their final positions. + */ + a[low] = a[lower]; a[lower] = pivot1; + a[end] = a[upper]; a[upper] = pivot2; - /* - * Sort the right part (possibly in parallel), excluding - * known pivot. All elements from the central part are - * equal and therefore already sorted. 
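partitionDualPivot above returns {lower, upper} and guarantees the three-region layout drawn in its comment. That postcondition is worth having as an executable check when validating any replacement implementation, such as the native stub. The reference construction below is not the patch's algorithm and uses a slightly different boundary convention (lower is the last index strictly below pivot1):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct PivotIndices { int lower, upper; };

// Reference construction of the dual-pivot postcondition: [low..lower] < pivot1,
// (lower..upper) within [pivot1, pivot2], and [upper..high) > pivot2.
static PivotIndices dual_pivot_reference(std::vector<int> &a, int low, int high,
                                         int pivot1, int pivot2) {
    auto mid  = std::partition(a.begin() + low, a.begin() + high,
                               [&](int v) { return v < pivot1; });
    auto mid2 = std::partition(mid, a.begin() + high,
                               [&](int v) { return v <= pivot2; });
    return { (int)(mid - a.begin()) - 1, (int)(mid2 - a.begin()) };
}

static void check(const std::vector<int> &a, int low, int high, PivotIndices p,
                  int pivot1, int pivot2) {
    for (int i = low; i <= p.lower; ++i) assert(a[i] < pivot1);
    for (int i = p.lower + 1; i < p.upper; ++i) assert(pivot1 <= a[i] && a[i] <= pivot2);
    for (int i = p.upper; i < high; ++i) assert(a[i] > pivot2);
}

int main() {
    std::vector<int> a = {9, 1, 5, 7, 3, 8, 2, 6, 4, 5};
    PivotIndices p = dual_pivot_reference(a, 0, (int)a.size(), 3, 7);
    check(a, 0, (int)a.size(), p, 3, 7);
}
```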
- */ - if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { - sorter.forkSorter(bits | 1, upper, high); - } else { - sort(sorter, a, bits | 1, upper, high); + return new int[] {lower, upper}; + } + + /** + * Partitions the specified range of the array using a single pivot provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionSinglePivot(int[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + + int end = high - 1; + int lower = low; + int upper = end; + int e3 = pivotIndex1; + int pivot = a[e3]; + + /* + * The first element to be sorted is moved to the + * location formerly occupied by the pivot. After + * completion of partitioning the pivot is swapped + * back into its final position, and excluded from + * the next subsequent sorting. + */ + a[e3] = a[lower]; + + /* + * Traditional 3-way (Dutch National Flag) partitioning + * + * left part central part right part + * +------------------------------------------------------+ + * | < pivot | ? | == pivot | > pivot | + * +------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot + * all in (k, upper) == pivot + * all in [upper, end] > pivot + * + * Pointer k is the last index of ?-part + */ + for (int k = ++upper; --k > lower; ) { + int ak = a[k]; + + if (ak != pivot) { + a[k] = pivot; + + if (ak < pivot) { // Move a[k] to the left side + while (a[++lower] < pivot); + + if (a[lower] > pivot) { + a[--upper] = a[lower]; + } + a[lower] = ak; + } else { // ak > pivot - Move a[k] to the right side + a[--upper] = ak; } } - high = lower; // Iterate along the left part } + + /* + * Swap the pivot into its final position. + */ + a[low] = a[lower]; a[lower] = pivot; + return new int[] {lower, upper}; } /** @@ -445,10 +567,11 @@ && tryMergeRuns(sorter, a, low, size)) { * * @param a the array to be sorted * @param low the index of the first element, inclusive, to be sorted - * @param end the index of the last element for simple insertion sort * @param high the index of the last element, exclusive, to be sorted */ - private static void mixedInsertionSort(int[] a, int low, int end, int high) { + private static void mixedInsertionSort(int[] a, int low, int high) { + int size = high - low; + int end = high - 3 * ((size >> 5) << 3); if (end == high) { /* @@ -937,7 +1060,7 @@ static void sort(Sorter sorter, long[] a, int bits, int low, int high) { * Run mixed insertion sort on small non-leftmost parts. */ if (size < MAX_MIXED_INSERTION_SORT_SIZE + bits && (bits & 1) > 0) { - mixedInsertionSort(a, low, high - 3 * ((size >> 5) << 3), high); + sort(long.class, a, Unsafe.ARRAY_LONG_BASE_OFFSET, low, high, DualPivotQuicksort::mixedInsertionSort); return; } @@ -945,7 +1068,7 @@ static void sort(Sorter sorter, long[] a, int bits, int low, int high) { * Invoke insertion sort on small leftmost part. 
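Two structural choices above merit a standalone illustration: the Dutch National Flag layout leaves the == pivot center needing no further work, and the enclosing sort loop recurses only on the right part while continuing with high = lower on the left. A compact scalar version showing both (a sketch, not the patch's partitioning):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Three-way partition, then: recurse right, loop left (high = lower).
// The == pivot middle is excluded from both sides, exactly the property the
// "already sorted" comment above relies on.
static void sort3way(int32_t *a, int low, int high) {   // sorts [low, high)
    while (high - low > 1) {
        int32_t pivot = a[low + (high - low) / 2];
        int32_t *lt = std::partition(a + low, a + high,
                                     [pivot](int32_t v) { return v < pivot; });
        int32_t *gt = std::partition(lt, a + high,
                                     [pivot](int32_t v) { return v == pivot; });
        sort3way(a, (int)(gt - a), high);  // right part: > pivot
        high = (int)(lt - a);              // iterate along the left part
    }
}

int main() {
    int32_t v[] = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    sort3way(v, 0, 11);
    assert(std::is_sorted(v, v + 11));
}
```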
*/ if (size < MAX_INSERTION_SORT_SIZE) { - insertionSort(a, low, high); + sort(long.class, a, Unsafe.ARRAY_LONG_BASE_OFFSET, low, high, DualPivotQuicksort::insertionSort); return; } @@ -1019,8 +1142,8 @@ && tryMergeRuns(sorter, a, low, size)) { } // Pointers - int lower = low; // The index of the last element of the left part - int upper = end; // The index of the first element of the right part + int lower; // The index of the last element of the left part + int upper; // The index of the first element of the right part /* * Partitioning with 2 pivots in case of different elements. @@ -1032,72 +1155,9 @@ && tryMergeRuns(sorter, a, low, size)) { * the pivots. These values are inexpensive approximation * of tertiles. Note, that pivot1 < pivot2. */ - long pivot1 = a[e1]; - long pivot2 = a[e5]; - - /* - * The first and the last elements to be sorted are moved - * to the locations formerly occupied by the pivots. When - * partitioning is completed, the pivots are swapped back - * into their final positions, and excluded from the next - * subsequent sorting. - */ - a[e1] = a[lower]; - a[e5] = a[upper]; - - /* - * Skip elements, which are less or greater than the pivots. - */ - while (a[++lower] < pivot1); - while (a[--upper] > pivot2); - - /* - * Backward 3-interval partitioning - * - * left part central part right part - * +------------------------------------------------------------+ - * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | - * +------------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot1 - * pivot1 <= all in (k, upper) <= pivot2 - * all in [upper, end) > pivot2 - * - * Pointer k is the last index of ?-part - */ - for (int unused = --lower, k = ++upper; --k > lower; ) { - long ak = a[k]; - - if (ak < pivot1) { // Move a[k] to the left side - while (lower < k) { - if (a[++lower] >= pivot1) { - if (a[lower] > pivot2) { - a[k] = a[--upper]; - a[upper] = a[lower]; - } else { - a[k] = a[lower]; - } - a[lower] = ak; - break; - } - } - } else if (ak > pivot2) { // Move a[k] to the right side - a[k] = a[--upper]; - a[upper] = ak; - } - } - - /* - * Swap the pivots into their final positions. - */ - a[low] = a[lower]; a[lower] = pivot1; - a[end] = a[upper]; a[upper] = pivot2; - + int[] pivotIndices = partition(long.class, a, Unsafe.ARRAY_LONG_BASE_OFFSET, low, high, e1, e5, DualPivotQuicksort::partitionDualPivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* * Sort non-left parts recursively (possibly in parallel), * excluding known pivots. @@ -1116,73 +1176,183 @@ && tryMergeRuns(sorter, a, low, size)) { * Use the third of the five sorted elements as the pivot. * This value is inexpensive approximation of the median. */ - long pivot = a[e3]; - + int[] pivotIndices = partition(long.class, a, Unsafe.ARRAY_LONG_BASE_OFFSET, low, high, e3, e3, DualPivotQuicksort::partitionSinglePivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* - * The first element to be sorted is moved to the - * location formerly occupied by the pivot. After - * completion of partitioning the pivot is swapped - * back into its final position, and excluded from - * the next subsequent sorting. + * Sort the right part (possibly in parallel), excluding + * known pivot. All elements from the central part are + * equal and therefore already sorted. 
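The "different elements" condition referenced above is the strategy switch: when the five sorted sample elements are strictly increasing, a[e1] and a[e5] are cheap tertile estimates and two pivots pay off; a duplicate among the samples hints at many equal keys, where the single-pivot three-way scheme wins. The predicate in isolation:

```cpp
#include <cstdint>

// Strategy predicate mirrored from the condition used above: dual pivots only
// when the five sorted samples are strictly increasing.
static bool use_dual_pivot(const int64_t *a,
                           int e1, int e2, int e3, int e4, int e5) {
    return a[e1] < a[e2] && a[e2] < a[e3] && a[e3] < a[e4] && a[e4] < a[e5];
}

int main() {
    int64_t distinct[] = {1, 3, 5, 7, 9};
    int64_t repeated[] = {1, 3, 5, 5, 9};
    return (use_dual_pivot(distinct, 0, 1, 2, 3, 4)       // true: pivots 1 and 9
            && !use_dual_pivot(repeated, 0, 1, 2, 3, 4))  // false: fall back to a[e3]
               ? 0 : 1;
}
```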
*/ - a[e3] = a[lower]; + if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { + sorter.forkSorter(bits | 1, upper, high); + } else { + sort(sorter, a, bits | 1, upper, high); + } + } + high = lower; // Iterate along the left part + } + } - /* - * Traditional 3-way (Dutch National Flag) partitioning - * - * left part central part right part - * +------------------------------------------------------+ - * | < pivot | ? | == pivot | > pivot | - * +------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot - * all in (k, upper) == pivot - * all in [upper, end] > pivot - * - * Pointer k is the last index of ?-part - */ - for (int k = ++upper; --k > lower; ) { - long ak = a[k]; + /** + * Partitions the specified range of the array using the two pivots provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionDualPivot(long[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + int end = high - 1; + int lower = low; + int upper = end; - if (ak != pivot) { - a[k] = pivot; + int e1 = pivotIndex1; + int e5 = pivotIndex2; + long pivot1 = a[e1]; + long pivot2 = a[e5]; - if (ak < pivot) { // Move a[k] to the left side - while (a[++lower] < pivot); + /* + * The first and the last elements to be sorted are moved + * to the locations formerly occupied by the pivots. When + * partitioning is completed, the pivots are swapped back + * into their final positions, and excluded from the next + * subsequent sorting. + */ + a[e1] = a[lower]; + a[e5] = a[upper]; - if (a[lower] > pivot) { - a[--upper] = a[lower]; - } - a[lower] = ak; - } else { // ak > pivot - Move a[k] to the right side - a[--upper] = ak; + /* + * Skip elements, which are less or greater than the pivots. + */ + while (a[++lower] < pivot1); + while (a[--upper] > pivot2); + + /* + * Backward 3-interval partitioning + * + * left part central part right part + * +------------------------------------------------------------+ + * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | + * +------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot1 + * pivot1 <= all in (k, upper) <= pivot2 + * all in [upper, end) > pivot2 + * + * Pointer k is the last index of ?-part + */ + for (int unused = --lower, k = ++upper; --k > lower; ) { + long ak = a[k]; + + if (ak < pivot1) { // Move a[k] to the left side + while (lower < k) { + if (a[++lower] >= pivot1) { + if (a[lower] > pivot2) { + a[k] = a[--upper]; + a[upper] = a[lower]; + } else { + a[k] = a[lower]; } + a[lower] = ak; + break; } } + } else if (ak > pivot2) { // Move a[k] to the right side + a[k] = a[--upper]; + a[upper] = ak; + } + } - /* - * Swap the pivot into its final position. - */ - a[low] = a[lower]; a[lower] = pivot; + /* + * Swap the pivots into their final positions. + */ + a[low] = a[lower]; a[lower] = pivot1; + a[end] = a[upper]; a[upper] = pivot2; - /* - * Sort the right part (possibly in parallel), excluding - * known pivot. All elements from the central part are - * equal and therefore already sorted. 
- */ - if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { - sorter.forkSorter(bits | 1, upper, high); - } else { - sort(sorter, a, bits | 1, upper, high); + return new int[] {lower, upper}; + } + + /** + * Partitions the specified range of the array using a single pivot provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionSinglePivot(long[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + + int end = high - 1; + int lower = low; + int upper = end; + + int e3 = pivotIndex1; + long pivot = a[e3]; + + /* + * The first element to be sorted is moved to the + * location formerly occupied by the pivot. After + * completion of partitioning the pivot is swapped + * back into its final position, and excluded from + * the next subsequent sorting. + */ + a[e3] = a[lower]; + + /* + * Traditional 3-way (Dutch National Flag) partitioning + * + * left part central part right part + * +------------------------------------------------------+ + * | < pivot | ? | == pivot | > pivot | + * +------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot + * all in (k, upper) == pivot + * all in [upper, end] > pivot + * + * Pointer k is the last index of ?-part + */ + for (int k = ++upper; --k > lower; ) { + long ak = a[k]; + + if (ak != pivot) { + a[k] = pivot; + + if (ak < pivot) { // Move a[k] to the left side + while (a[++lower] < pivot); + + if (a[lower] > pivot) { + a[--upper] = a[lower]; + } + a[lower] = ak; + } else { // ak > pivot - Move a[k] to the right side + a[--upper] = ak; } } - high = lower; // Iterate along the left part } + + /* + * Swap the pivot into its final position. + */ + a[low] = a[lower]; a[lower] = pivot; + return new int[] {lower, upper}; } /** @@ -1199,10 +1369,11 @@ && tryMergeRuns(sorter, a, low, size)) { * * @param a the array to be sorted * @param low the index of the first element, inclusive, to be sorted - * @param end the index of the last element for simple insertion sort * @param high the index of the last element, exclusive, to be sorted */ - private static void mixedInsertionSort(long[] a, int low, int end, int high) { + private static void mixedInsertionSort(long[] a, int low, int high) { + int size = high - low; + int end = high - 3 * ((size >> 5) << 3); if (end == high) { /* @@ -2478,7 +2649,7 @@ static void sort(Sorter sorter, float[] a, int bits, int low, int high) { * Run mixed insertion sort on small non-leftmost parts. */ if (size < MAX_MIXED_INSERTION_SORT_SIZE + bits && (bits & 1) > 0) { - mixedInsertionSort(a, low, high - 3 * ((size >> 5) << 3), high); + sort(float.class, a, Unsafe.ARRAY_FLOAT_BASE_OFFSET, low, high, DualPivotQuicksort::mixedInsertionSort); return; } @@ -2486,7 +2657,7 @@ static void sort(Sorter sorter, float[] a, int bits, int low, int high) { * Invoke insertion sort on small leftmost part. 
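mixedInsertionSort now derives its split point internally: end = high - 3 * ((size >> 5) << 3), which is high minus 24 * floor(size / 32), so roughly the last three quarters of the range goes to the pin-based variant. The shifts obscure the numbers, so a worked example:

```cpp
#include <cstdio>

// Split-point arithmetic from mixedInsertionSort above.
static int split_point(int low, int high) {
    int size = high - low;
    return high - 3 * ((size >> 5) << 3);   // == high - 24 * (size / 32)
}

int main() {
    printf("%d\n", split_point(0, 100));  // size 100: 100 - 3*24 = 28
    printf("%d\n", split_point(0, 31));   // size < 32: end == high, plain insertion sort
}
```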
*/ if (size < MAX_INSERTION_SORT_SIZE) { - insertionSort(a, low, high); + sort(float.class, a, Unsafe.ARRAY_FLOAT_BASE_OFFSET, low, high, DualPivotQuicksort::insertionSort); return; } @@ -2560,8 +2731,8 @@ && tryMergeRuns(sorter, a, low, size)) { } // Pointers - int lower = low; // The index of the last element of the left part - int upper = end; // The index of the first element of the right part + int lower; // The index of the last element of the left part + int upper; // The index of the first element of the right part /* * Partitioning with 2 pivots in case of different elements. @@ -2573,72 +2744,9 @@ && tryMergeRuns(sorter, a, low, size)) { * the pivots. These values are inexpensive approximation * of tertiles. Note, that pivot1 < pivot2. */ - float pivot1 = a[e1]; - float pivot2 = a[e5]; - - /* - * The first and the last elements to be sorted are moved - * to the locations formerly occupied by the pivots. When - * partitioning is completed, the pivots are swapped back - * into their final positions, and excluded from the next - * subsequent sorting. - */ - a[e1] = a[lower]; - a[e5] = a[upper]; - - /* - * Skip elements, which are less or greater than the pivots. - */ - while (a[++lower] < pivot1); - while (a[--upper] > pivot2); - - /* - * Backward 3-interval partitioning - * - * left part central part right part - * +------------------------------------------------------------+ - * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | - * +------------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot1 - * pivot1 <= all in (k, upper) <= pivot2 - * all in [upper, end) > pivot2 - * - * Pointer k is the last index of ?-part - */ - for (int unused = --lower, k = ++upper; --k > lower; ) { - float ak = a[k]; - - if (ak < pivot1) { // Move a[k] to the left side - while (lower < k) { - if (a[++lower] >= pivot1) { - if (a[lower] > pivot2) { - a[k] = a[--upper]; - a[upper] = a[lower]; - } else { - a[k] = a[lower]; - } - a[lower] = ak; - break; - } - } - } else if (ak > pivot2) { // Move a[k] to the right side - a[k] = a[--upper]; - a[upper] = ak; - } - } - - /* - * Swap the pivots into their final positions. - */ - a[low] = a[lower]; a[lower] = pivot1; - a[end] = a[upper]; a[upper] = pivot2; - + int[] pivotIndices = partition(float.class, a, Unsafe.ARRAY_FLOAT_BASE_OFFSET, low, high, e1, e5, DualPivotQuicksort::partitionDualPivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* * Sort non-left parts recursively (possibly in parallel), * excluding known pivots. @@ -2657,73 +2765,182 @@ && tryMergeRuns(sorter, a, low, size)) { * Use the third of the five sorted elements as the pivot. * This value is inexpensive approximation of the median. */ - float pivot = a[e3]; - + int[] pivotIndices = partition(float.class, a, Unsafe.ARRAY_FLOAT_BASE_OFFSET, low, high, e3, e3, DualPivotQuicksort::partitionSinglePivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* - * The first element to be sorted is moved to the - * location formerly occupied by the pivot. After - * completion of partitioning the pivot is swapped - * back into its final position, and excluded from - * the next subsequent sorting. + * Sort the right part (possibly in parallel), excluding + * known pivot. All elements from the central part are + * equal and therefore already sorted. 
*/ - a[e3] = a[lower]; + if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { + sorter.forkSorter(bits | 1, upper, high); + } else { + sort(sorter, a, bits | 1, upper, high); + } + } + high = lower; // Iterate along the left part + } + } - /* - * Traditional 3-way (Dutch National Flag) partitioning - * - * left part central part right part - * +------------------------------------------------------+ - * | < pivot | ? | == pivot | > pivot | - * +------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot - * all in (k, upper) == pivot - * all in [upper, end] > pivot - * - * Pointer k is the last index of ?-part - */ - for (int k = ++upper; --k > lower; ) { - float ak = a[k]; + /** + * Partitions the specified range of the array using the two pivots provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionDualPivot(float[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + int end = high - 1; + int lower = low; + int upper = end; - if (ak != pivot) { - a[k] = pivot; + int e1 = pivotIndex1; + int e5 = pivotIndex2; + float pivot1 = a[e1]; + float pivot2 = a[e5]; - if (ak < pivot) { // Move a[k] to the left side - while (a[++lower] < pivot); + /* + * The first and the last elements to be sorted are moved + * to the locations formerly occupied by the pivots. When + * partitioning is completed, the pivots are swapped back + * into their final positions, and excluded from the next + * subsequent sorting. + */ + a[e1] = a[lower]; + a[e5] = a[upper]; - if (a[lower] > pivot) { - a[--upper] = a[lower]; - } - a[lower] = ak; - } else { // ak > pivot - Move a[k] to the right side - a[--upper] = ak; + /* + * Skip elements, which are less or greater than the pivots. + */ + while (a[++lower] < pivot1); + while (a[--upper] > pivot2); + + /* + * Backward 3-interval partitioning + * + * left part central part right part + * +------------------------------------------------------------+ + * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | + * +------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot1 + * pivot1 <= all in (k, upper) <= pivot2 + * all in [upper, end) > pivot2 + * + * Pointer k is the last index of ?-part + */ + for (int unused = --lower, k = ++upper; --k > lower; ) { + float ak = a[k]; + + if (ak < pivot1) { // Move a[k] to the left side + while (lower < k) { + if (a[++lower] >= pivot1) { + if (a[lower] > pivot2) { + a[k] = a[--upper]; + a[upper] = a[lower]; + } else { + a[k] = a[lower]; } + a[lower] = ak; + break; } } + } else if (ak > pivot2) { // Move a[k] to the right side + a[k] = a[--upper]; + a[upper] = ak; + } + } - /* - * Swap the pivot into its final position. - */ - a[low] = a[lower]; a[lower] = pivot; + /* + * Swap the pivots into their final positions. + */ + a[low] = a[lower]; a[lower] = pivot1; + a[end] = a[upper]; a[upper] = pivot2; - /* - * Sort the right part (possibly in parallel), excluding - * known pivot. All elements from the central part are - * equal and therefore already sorted. 
- */ - if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { - sorter.forkSorter(bits | 1, upper, high); - } else { - sort(sorter, a, bits | 1, upper, high); + return new int[] {lower, upper}; + } + + /** + * Partitions the specified range of the array using a single pivot provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionSinglePivot(float[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + int end = high - 1; + int lower = low; + int upper = end; + + int e3 = pivotIndex1; + float pivot = a[e3]; + + /* + * The first element to be sorted is moved to the + * location formerly occupied by the pivot. After + * completion of partitioning the pivot is swapped + * back into its final position, and excluded from + * the next subsequent sorting. + */ + a[e3] = a[lower]; + + /* + * Traditional 3-way (Dutch National Flag) partitioning + * + * left part central part right part + * +------------------------------------------------------+ + * | < pivot | ? | == pivot | > pivot | + * +------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot + * all in (k, upper) == pivot + * all in [upper, end] > pivot + * + * Pointer k is the last index of ?-part + */ + for (int k = ++upper; --k > lower; ) { + float ak = a[k]; + + if (ak != pivot) { + a[k] = pivot; + + if (ak < pivot) { // Move a[k] to the left side + while (a[++lower] < pivot); + + if (a[lower] > pivot) { + a[--upper] = a[lower]; + } + a[lower] = ak; + } else { // ak > pivot - Move a[k] to the right side + a[--upper] = ak; } } - high = lower; // Iterate along the left part } + + /* + * Swap the pivot into its final position. + */ + a[low] = a[lower]; a[lower] = pivot; + return new int[] {lower, upper}; } /** @@ -2740,10 +2957,11 @@ && tryMergeRuns(sorter, a, low, size)) { * * @param a the array to be sorted * @param low the index of the first element, inclusive, to be sorted - * @param end the index of the last element for simple insertion sort * @param high the index of the last element, exclusive, to be sorted */ - private static void mixedInsertionSort(float[] a, int low, int end, int high) { + private static void mixedInsertionSort(float[] a, int low, int high) { + int size = high - low; + int end = high - 3 * ((size >> 5) << 3); if (end == high) { /* @@ -3279,12 +3497,11 @@ static void sort(double[] a, int parallelism, int low, int high) { static void sort(Sorter sorter, double[] a, int bits, int low, int high) { while (true) { int end = high - 1, size = high - low; - /* * Run mixed insertion sort on small non-leftmost parts. */ if (size < MAX_MIXED_INSERTION_SORT_SIZE + bits && (bits & 1) > 0) { - mixedInsertionSort(a, low, high - 3 * ((size >> 5) << 3), high); + sort(double.class, a, Unsafe.ARRAY_DOUBLE_BASE_OFFSET, low, high, DualPivotQuicksort::mixedInsertionSort); return; } @@ -3292,7 +3509,7 @@ static void sort(Sorter sorter, double[] a, int bits, int low, int high) { * Invoke insertion sort on small leftmost part. 
*/ if (size < MAX_INSERTION_SORT_SIZE) { - insertionSort(a, low, high); + sort(double.class, a, Unsafe.ARRAY_DOUBLE_BASE_OFFSET, low, high, DualPivotQuicksort::insertionSort); return; } @@ -3366,8 +3583,8 @@ && tryMergeRuns(sorter, a, low, size)) { } // Pointers - int lower = low; // The index of the last element of the left part - int upper = end; // The index of the first element of the right part + int lower; // The index of the last element of the left part + int upper; // The index of the first element of the right part /* * Partitioning with 2 pivots in case of different elements. @@ -3375,76 +3592,13 @@ && tryMergeRuns(sorter, a, low, size)) { if (a[e1] < a[e2] && a[e2] < a[e3] && a[e3] < a[e4] && a[e4] < a[e5]) { /* - * Use the first and fifth of the five sorted elements as - * the pivots. These values are inexpensive approximation - * of tertiles. Note, that pivot1 < pivot2. - */ - double pivot1 = a[e1]; - double pivot2 = a[e5]; - - /* - * The first and the last elements to be sorted are moved - * to the locations formerly occupied by the pivots. When - * partitioning is completed, the pivots are swapped back - * into their final positions, and excluded from the next - * subsequent sorting. - */ - a[e1] = a[lower]; - a[e5] = a[upper]; - - /* - * Skip elements, which are less or greater than the pivots. - */ - while (a[++lower] < pivot1); - while (a[--upper] > pivot2); - - /* - * Backward 3-interval partitioning - * - * left part central part right part - * +------------------------------------------------------------+ - * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | - * +------------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot1 - * pivot1 <= all in (k, upper) <= pivot2 - * all in [upper, end) > pivot2 - * - * Pointer k is the last index of ?-part - */ - for (int unused = --lower, k = ++upper; --k > lower; ) { - double ak = a[k]; - - if (ak < pivot1) { // Move a[k] to the left side - while (lower < k) { - if (a[++lower] >= pivot1) { - if (a[lower] > pivot2) { - a[k] = a[--upper]; - a[upper] = a[lower]; - } else { - a[k] = a[lower]; - } - a[lower] = ak; - break; - } - } - } else if (ak > pivot2) { // Move a[k] to the right side - a[k] = a[--upper]; - a[upper] = ak; - } - } - - /* - * Swap the pivots into their final positions. - */ - a[low] = a[lower]; a[lower] = pivot1; - a[end] = a[upper]; a[upper] = pivot2; - + * Use the first and fifth of the five sorted elements as + * the pivots. These values are inexpensive approximation + * of tertiles. Note, that pivot1 < pivot2. + */ + int[] pivotIndices = partition(double.class, a, Unsafe.ARRAY_DOUBLE_BASE_OFFSET, low, high, e1, e5, DualPivotQuicksort::partitionDualPivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* * Sort non-left parts recursively (possibly in parallel), * excluding known pivots. @@ -3463,73 +3617,183 @@ && tryMergeRuns(sorter, a, low, size)) { * Use the third of the five sorted elements as the pivot. * This value is inexpensive approximation of the median. */ - double pivot = a[e3]; + int[] pivotIndices = partition(double.class, a, Unsafe.ARRAY_DOUBLE_BASE_OFFSET, low, high, e3, e3, DualPivotQuicksort::partitionSinglePivot); + lower = pivotIndices[0]; + upper = pivotIndices[1]; /* - * The first element to be sorted is moved to the - * location formerly occupied by the pivot. 
After - * completion of partitioning the pivot is swapped - * back into its final position, and excluded from - * the next subsequent sorting. + * Sort the right part (possibly in parallel), excluding + * known pivot. All elements from the central part are + * equal and therefore already sorted. */ - a[e3] = a[lower]; + if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { + sorter.forkSorter(bits | 1, upper, high); + } else { + sort(sorter, a, bits | 1, upper, high); + } + } + high = lower; // Iterate along the left part + } + } - /* - * Traditional 3-way (Dutch National Flag) partitioning - * - * left part central part right part - * +------------------------------------------------------+ - * | < pivot | ? | == pivot | > pivot | - * +------------------------------------------------------+ - * ^ ^ ^ - * | | | - * lower k upper - * - * Invariants: - * - * all in (low, lower] < pivot - * all in (k, upper) == pivot - * all in [upper, end] > pivot - * - * Pointer k is the last index of ?-part - */ - for (int k = ++upper; --k > lower; ) { - double ak = a[k]; + /** + * Partitions the specified range of the array using the two pivots provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + * + */ + @ForceInline + private static int[] partitionDualPivot(double[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + int end = high - 1; + int lower = low; + int upper = end; - if (ak != pivot) { - a[k] = pivot; + int e1 = pivotIndex1; + int e5 = pivotIndex2; + double pivot1 = a[e1]; + double pivot2 = a[e5]; - if (ak < pivot) { // Move a[k] to the left side - while (a[++lower] < pivot); + /* + * The first and the last elements to be sorted are moved + * to the locations formerly occupied by the pivots. When + * partitioning is completed, the pivots are swapped back + * into their final positions, and excluded from the next + * subsequent sorting. + */ + a[e1] = a[lower]; + a[e5] = a[upper]; - if (a[lower] > pivot) { - a[--upper] = a[lower]; - } - a[lower] = ak; - } else { // ak > pivot - Move a[k] to the right side - a[--upper] = ak; + /* + * Skip elements, which are less or greater than the pivots. + */ + while (a[++lower] < pivot1); + while (a[--upper] > pivot2); + + /* + * Backward 3-interval partitioning + * + * left part central part right part + * +------------------------------------------------------------+ + * | < pivot1 | ? | pivot1 <= && <= pivot2 | > pivot2 | + * +------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot1 + * pivot1 <= all in (k, upper) <= pivot2 + * all in [upper, end) > pivot2 + * + * Pointer k is the last index of ?-part + */ + for (int unused = --lower, k = ++upper; --k > lower; ) { + double ak = a[k]; + + if (ak < pivot1) { // Move a[k] to the left side + while (lower < k) { + if (a[++lower] >= pivot1) { + if (a[lower] > pivot2) { + a[k] = a[--upper]; + a[upper] = a[lower]; + } else { + a[k] = a[lower]; } + a[lower] = ak; + break; } } + } else if (ak > pivot2) { // Move a[k] to the right side + a[k] = a[--upper]; + a[upper] = ak; + } + } - /* - * Swap the pivot into its final position. - */ - a[low] = a[lower]; a[lower] = pivot; + /* + * Swap the pivots into their final positions. 
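+         * (pivot1 and pivot2 were saved in locals while their array
+         * slots were reused during the scan; writing them to a[lower]
+         * and a[upper] places each pivot exactly between its
+         * partitions, and the returned {lower, upper} pair lets the
+         * caller exclude both pivots from further sorting.)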
+ */ + a[low] = a[lower]; a[lower] = pivot1; + a[end] = a[upper]; a[upper] = pivot2; - /* - * Sort the right part (possibly in parallel), excluding - * known pivot. All elements from the central part are - * equal and therefore already sorted. - */ - if (size > MIN_PARALLEL_SORT_SIZE && sorter != null) { - sorter.forkSorter(bits | 1, upper, high); - } else { - sort(sorter, a, bits | 1, upper, high); + return new int[] {lower, upper}; + } + + /** + * Partitions the specified range of the array using a single pivot provided. + * + * @param array the array to be partitioned + * @param low the index of the first element, inclusive, for partitioning + * @param high the index of the last element, exclusive, for partitioning + * @param pivotIndex1 the index of pivot1, the first pivot + * @param pivotIndex2 the index of pivot2, the second pivot + */ + @ForceInline + private static int[] partitionSinglePivot(double[] a, int low, int high, int pivotIndex1, int pivotIndex2) { + + int end = high - 1; + int lower = low; + int upper = end; + + int e3 = pivotIndex1; + double pivot = a[e3]; + + /* + * The first element to be sorted is moved to the + * location formerly occupied by the pivot. After + * completion of partitioning the pivot is swapped + * back into its final position, and excluded from + * the next subsequent sorting. + */ + a[e3] = a[lower]; + + /* + * Traditional 3-way (Dutch National Flag) partitioning + * + * left part central part right part + * +------------------------------------------------------+ + * | < pivot | ? | == pivot | > pivot | + * +------------------------------------------------------+ + * ^ ^ ^ + * | | | + * lower k upper + * + * Invariants: + * + * all in (low, lower] < pivot + * all in (k, upper) == pivot + * all in [upper, end] > pivot + * + * Pointer k is the last index of ?-part + */ + for (int k = ++upper; --k > lower; ) { + double ak = a[k]; + + if (ak != pivot) { + a[k] = pivot; + + if (ak < pivot) { // Move a[k] to the left side + while (a[++lower] < pivot); + + if (a[lower] > pivot) { + a[--upper] = a[lower]; + } + a[lower] = ak; + } else { // ak > pivot - Move a[k] to the right side + a[--upper] = ak; } } - high = lower; // Iterate along the left part } + + /* + * Swap the pivot into its final position. + */ + a[low] = a[lower]; a[lower] = pivot; + return new int[] {lower, upper}; } /** @@ -3546,10 +3810,11 @@ && tryMergeRuns(sorter, a, low, size)) { * * @param a the array to be sorted * @param low the index of the first element, inclusive, to be sorted - * @param end the index of the last element for simple insertion sort * @param high the index of the last element, exclusive, to be sorted */ - private static void mixedInsertionSort(double[] a, int low, int end, int high) { + private static void mixedInsertionSort(double[] a, int low, int high) { + int size = high - low; + int end = high - 3 * ((size >> 5) << 3); if (end == high) { /* diff --git a/test/jdk/java/util/Arrays/Sorting.java b/test/jdk/java/util/Arrays/Sorting.java index e89496bb2e532..f285b0c65b72c 100644 --- a/test/jdk/java/util/Arrays/Sorting.java +++ b/test/jdk/java/util/Arrays/Sorting.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2009, 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2009, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -26,7 +26,8 @@ * @compile/module=java.base java/util/SortingHelper.java * @bug 6880672 6896573 6899694 6976036 7013585 7018258 8003981 8226297 * @build Sorting - * @run main Sorting -shortrun + * @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:DisableIntrinsic=_arraySort,_arrayPartition Sorting -shortrun + * @run main/othervm -XX:-TieredCompilation -XX:CompileCommand=CompileThresholdScaling,java.util.DualPivotQuicksort::sort,0.0001 Sorting -shortrun * @summary Exercise Arrays.sort, Arrays.parallelSort * * @author Vladimir Yaroslavskiy diff --git a/test/micro/org/openjdk/bench/java/util/ArraysSort.java b/test/micro/org/openjdk/bench/java/util/ArraysSort.java new file mode 100644 index 0000000000000..4cd45d79412c1 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/util/ArraysSort.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ +package org.openjdk.bench.java.lang; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.io.UnsupportedEncodingException; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; + +/** + * Performance test of Arrays.sort() methods + */ +@Fork(value=1, jvmArgsAppend={"-XX:CompileThreshold=1", "-XX:-TieredCompilation"}) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Thread) +@Warmup(iterations = 3, time=5) +@Measurement(iterations = 3, time=3) +public class ArraysSort { + + @Param({"10","25","50","75","100", "1000", "10000", "100000", "1000000"}) + private int size; + + private int[] ints_unsorted; + private long[] longs_unsorted; + private float[] floats_unsorted; + private double[] doubles_unsorted; + + private int[] ints_sorted; + private long[] longs_sorted; + private float[] floats_sorted; + private double[] doubles_sorted; + + + public void initialize() { + Random rnd = new Random(42); + + ints_unsorted = new int[size]; + longs_unsorted = new long[size]; + floats_unsorted = new float[size]; + doubles_unsorted = new double[size]; + + int[] intSpecialCases = {Integer.MIN_VALUE, Integer.MAX_VALUE}; + long[] longSpecialCases = {Long.MIN_VALUE, Long.MAX_VALUE}; + float[] floatSpecialCases = {+0.0f, -0.0f, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NaN}; + double[] doubleSpecialCases = {+0.0, -0.0, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN}; + + for (int i = 0; i < size; i++) { + ints_unsorted[i] = rnd.nextInt(); + longs_unsorted[i] = rnd.nextLong(); + if (i % 10 != 0) { + ints_unsorted[i] = rnd.nextInt(); + longs_unsorted[i] = rnd.nextLong(); + floats_unsorted[i] = rnd.nextFloat(); + doubles_unsorted[i] = rnd.nextDouble(); + } else { + ints_unsorted[i] = intSpecialCases[rnd.nextInt(intSpecialCases.length)]; + longs_unsorted[i] = longSpecialCases[rnd.nextInt(longSpecialCases.length)]; + floats_unsorted[i] = floatSpecialCases[rnd.nextInt(floatSpecialCases.length)]; + doubles_unsorted[i] = doubleSpecialCases[rnd.nextInt(doubleSpecialCases.length)]; + } + } + } + + @Setup + public void setup() throws UnsupportedEncodingException, ClassNotFoundException, NoSuchMethodException, Throwable { + initialize(); + } + + @Setup(Level.Invocation) + public void clear() { + ints_sorted = ints_unsorted.clone(); + longs_sorted = longs_unsorted.clone(); + floats_sorted = floats_unsorted.clone(); + doubles_sorted = doubles_unsorted.clone(); + } + + @Benchmark + public int[] intSort() throws Throwable { + Arrays.sort(ints_sorted); + return ints_sorted; + } + + @Benchmark + public int[] intParallelSort() throws Throwable { + Arrays.parallelSort(ints_sorted); + return ints_sorted; + } + + @Benchmark + public long[] longSort() throws Throwable { + Arrays.sort(longs_sorted); + return 
longs_sorted; + } + + @Benchmark + public long[] longParallelSort() throws Throwable { + Arrays.parallelSort(longs_sorted); + return longs_sorted; + } + + @Benchmark + public float[] floatSort() throws Throwable { + Arrays.sort(floats_sorted); + return floats_sorted; + } + + @Benchmark + public float[] floatParallelSort() throws Throwable { + Arrays.parallelSort(floats_sorted); + return floats_sorted; + } + + @Benchmark + public double[] doubleSort() throws Throwable { + Arrays.sort(doubles_sorted); + return doubles_sorted; + } + + @Benchmark + public double[] doubleParallelSort() throws Throwable { + Arrays.parallelSort(doubles_sorted); + return doubles_sorted; + } + +}
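
Note on the Java-side shape of this change: each primitive specialization above now routes its small-range sorts and its partitioning through @ForceInline helpers and an intrinsified dispatch method that receives the Java fallback as a method reference (the DualPivotQuicksort$SortOperation / $PartitionOperation types named in the new vmIntrinsics entries). A minimal, self-contained sketch of that pattern — illustrative names only, not the JDK's actual declarations, and without the Class/Unsafe-offset arguments the real dispatch carries:

    // Sketch of the fallback-as-argument dispatch behind the new
    // _arraySort/_arrayPartition intrinsics: the dispatch method is the
    // intrinsic candidate, so C2 can replace the whole call with the
    // avx512 stub, while the interpreter/C1 just invoke the method
    // reference.
    interface SortOp {
        void sort(int[] a, int low, int high);
    }

    final class Demo {
        static void sort(int[] a, int low, int high, SortOp fallback) {
            // Intrinsified in the real code; plain Java fallback here.
            fallback.sort(a, low, high);
        }

        static void insertionSort(int[] a, int low, int high) {
            // low inclusive, high exclusive, matching the JDK convention.
            for (int i = low + 1; i < high; i++) {
                int v = a[i], j = i - 1;
                while (j >= low && a[j] > v) {
                    a[j + 1] = a[j];
                    j--;
                }
                a[j + 1] = v;
            }
        }

        public static void main(String[] args) {
            int[] a = {5, 1, 4, 2, 3};
            sort(a, 0, a.length, Demo::insertionSort);
            System.out.println(java.util.Arrays.toString(a)); // [1, 2, 3, 4, 5]
        }
    }

The point of this shape is that only the small dispatch method needs an intrinsic mapping: when the stub is unavailable (wrong CPU, library not found, intrinsic disabled), execution falls through to the same Java code as before, with no per-type duplication.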
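Note on local verification — a sketch, assuming a linux-x64 gcc build on an AVX-512 DQ capable CPU (the only configuration that builds and loads libsimdsort); exact make/jtreg syntax is per doc/testing.md:

    # Run the updated jtreg test; its first @run line disables the new
    # intrinsics, the second forces early compilation of the sort paths.
    make test TEST="test/jdk/java/util/Arrays/Sorting.java"

    # Run the new JMH micro.
    make test TEST="micro:java.util.ArraysSort"

    # Check whether libsimdsort was located and loaded at startup, using
    # the unified-logging tag the stub generator logs under.
    java -Xlog:library=info -version | grep -i simdsort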