diff --git a/ci/BuildCPUBuilderImage.groovy b/ci/BuildCPUBuilderImage.groovy index 06022fe53..b4abfc3eb 100644 --- a/ci/BuildCPUBuilderImage.groovy +++ b/ci/BuildCPUBuilderImage.groovy @@ -92,9 +92,9 @@ spec: parameters{ string( - description: 'os(ubuntu20.04,centos7,ubuntu18.04)', + description: 'os(ubuntu22.04,ubuntu20.04,centos7,ubuntu18.04)', name: 'os', - defaultValue: 'ubuntu20.04' + defaultValue: 'ubuntu22.04' ) } stages { diff --git a/ci/BuildGPUBuilderImage.groovy b/ci/BuildGPUBuilderImage.groovy index a71ca2eca..0cc899e1a 100644 --- a/ci/BuildGPUBuilderImage.groovy +++ b/ci/BuildGPUBuilderImage.groovy @@ -92,9 +92,9 @@ spec: parameters{ string( - description: 'os(ubuntu20.04,centos7,ubuntu18.04)', + description: 'os(ubuntu22.04,ubuntu20.04,centos7,ubuntu18.04)', name: 'os', - defaultValue: 'ubuntu20.04' + defaultValue: 'ubuntu22.04' ) } stages { diff --git a/ci/docker/builder/cpu/ubuntu22.04/Dockerfile b/ci/docker/builder/cpu/ubuntu22.04/Dockerfile new file mode 100644 index 000000000..a2bbd57b9 --- /dev/null +++ b/ci/docker/builder/cpu/ubuntu22.04/Dockerfile @@ -0,0 +1,50 @@ +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV CMAKE_VERSION="v3.28" +ENV CMAKE_TAR="cmake-3.28.4-linux-x86_64.tar.gz" +ENV CCACHE_VERSION="v4.9.1" +ENV CCACHE_TAR="ccache-4.9.1-linux-x86_64.tar.xz" + +RUN apt update \ + && apt install -y ca-certificates apt-transport-https software-properties-common lsb-release \ + && gpg --list-keys \ + && gpg --no-default-keyring --keyring /usr/share/keyrings/deadsnakes.gpg \ + --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 \ + && echo "deb [signed-by=/usr/share/keyrings/deadsnakes.gpg] https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu \ + $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/python.list \ + && apt install -y --no-install-recommends wget curl git g++ gcc make gfortran swig \ + && apt install -y python3.11 python3.11-dev python3.11-distutils \ + && apt install -y python3-setuptools \ + && cd /usr/bin \ + && unlink python3 && ln -s python3.11 python3 \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python3 \ + && pip3 install wheel \ + && apt remove --purge -y \ + && rm -rf /var/lib/apt/lists/* + +# install cmake and ccache +RUN cd /tmp \ + && wget --tries=3 --retry-connrefused "https://cmake.org/files/${CMAKE_VERSION}/${CMAKE_TAR}" \ + && tar --strip-components=1 -xz -C /usr/local -f ${CMAKE_TAR} \ + && rm -f ${CMAKE_TAR} \ + && wget --tries=3 --retry-connrefused "https://github.com/ccache/ccache/releases/download/${CCACHE_VERSION}/${CCACHE_TAR}" \ + && tar -xf ${CCACHE_TAR} \ + && cp ccache-4.9.1-linux-x86_64/ccache /usr/local/bin \ + && rm -f ${CCACHE_TAR} + +# install knowhere dependancies +RUN apt update \ + && apt install -y libopenblas-dev libcurl4-openssl-dev libaio-dev libevent-dev lcov \ + && pip3 install conan==1.61.0 \ + && conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local \ + && export PATH=$PATH:$HOME/.local/bin + +# clone knowhere repo and build to update .conan +RUN git clone https://github.com/zilliztech/knowhere.git \ + && cd knowhere \ + && mkdir build && cd build \ + && conan install .. --build=missing -o with_ut=True -o with_diskann=True -s compiler.libcxx=libstdc++11 -s build_type=Release \ + && conan build .. \ + && cd ../.. \ + && rm -rf knowhere diff --git a/ci/docker/builder/gpu/ubuntu22.04/Dockerfile b/ci/docker/builder/gpu/ubuntu22.04/Dockerfile new file mode 100644 index 000000000..2dea66717 --- /dev/null +++ b/ci/docker/builder/gpu/ubuntu22.04/Dockerfile @@ -0,0 +1,50 @@ +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV CMAKE_VERSION="v3.28" +ENV CMAKE_TAR="cmake-3.28.4-linux-x86_64.tar.gz" +ENV CCACHE_VERSION="v4.9.1" +ENV CCACHE_TAR="ccache-4.9.1-linux-x86_64.tar.xz" + +RUN apt update \ + && apt install -y ca-certificates apt-transport-https software-properties-common lsb-release \ + && gpg --list-keys \ + && gpg --no-default-keyring --keyring /usr/share/keyrings/deadsnakes.gpg \ + --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 \ + && echo "deb [signed-by=/usr/share/keyrings/deadsnakes.gpg] https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu \ + $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/python.list \ + && apt install -y --no-install-recommends wget curl git g++ gcc make gfortran swig \ + && apt install -y python3.11 python3.11-dev python3.11-distutils \ + && apt install -y python3-setuptools \ + && cd /usr/bin \ + && unlink python3 && ln -s python3.11 python3 \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python3 \ + && pip3 install wheel \ + && apt remove --purge -y \ + && rm -rf /var/lib/apt/lists/* + +# install cmake and ccache +RUN cd /tmp \ + && wget --tries=3 --retry-connrefused "https://cmake.org/files/${CMAKE_VERSION}/${CMAKE_TAR}" \ + && tar --strip-components=1 -xz -C /usr/local -f ${CMAKE_TAR} \ + && rm -f ${CMAKE_TAR} \ + && wget --tries=3 --retry-connrefused "https://github.com/ccache/ccache/releases/download/${CCACHE_VERSION}/${CCACHE_TAR}" \ + && tar -xf ${CCACHE_TAR} \ + && cp ccache-4.9.1-linux-x86_64/ccache /usr/local/bin \ + && rm -f ${CCACHE_TAR} + +# install knowhere dependancies +RUN apt update \ + && apt install -y libopenblas-dev libcurl4-openssl-dev libaio-dev libevent-dev lcov \ + && pip3 install conan==1.61.0 \ + && conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local \ + && export PATH=$PATH:$HOME/.local/bin + +# clone knowhere repo and build to update .conan +RUN git clone https://github.com/zilliztech/knowhere.git \ + && cd knowhere \ + && mkdir build && cd build \ + && conan install .. --build=missing -o with_ut=True -o with_diskann=True -o with_raft=True -s compiler.libcxx=libstdc++11 -s build_type=Release \ + && conan build .. \ + && cd ../.. \ + && rm -rf knowhere diff --git a/cmake/libs/libcardinal.cmake b/cmake/libs/libcardinal.cmake index 1a6b0125f..25ae5c003 100644 --- a/cmake/libs/libcardinal.cmake +++ b/cmake/libs/libcardinal.cmake @@ -1,5 +1,5 @@ # Use short SHA1 as version -set(CARDINAL_VERSION 0a78003 ) +set(CARDINAL_VERSION c902b74 ) set(CARDINAL_REPO_URL "https://github.com/zilliztech/cardinal.git") set(CARDINAL_REPO_DIR "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/cardinal") diff --git a/include/knowhere/comp/knowhere_config.h b/include/knowhere/comp/knowhere_config.h index ba224cca7..d663a87d4 100644 --- a/include/knowhere/comp/knowhere_config.h +++ b/include/knowhere/comp/knowhere_config.h @@ -40,6 +40,17 @@ class KnowhereConfig { static std::string SetSimdType(const SimdType simd_type); + /** + *The purpose of this interface is: part of the sealed indexes default to using bf16 as the base data to achieve + *higher capacity; to ensure consistency in computation between growing and sealed, it is necessary to maintain the + *same precision in growing calculations as in sealed. + */ + static void + EnablePatchForComputeFP32AsBF16(); + + static void + DisablePatchForComputeFP32AsBF16(); + /** * Set openblas threshold * if nq < use_blas_threshold, calculated by omp diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h index 91034603a..8873f053e 100644 --- a/include/knowhere/operands.h +++ b/include/knowhere/operands.h @@ -24,6 +24,13 @@ union fp32_bits { float as_value; }; +__attribute__((always_inline)) inline float +bf16_float(float f) { + auto u32 = fp32_bits{.as_value = f}.as_bits; + // Round off + return fp32_bits{.as_bits = (u32 + 0x8000) & 0xFFFF0000}.as_value; +} + inline float fp32_from_bits(const uint32_t& w) { return fp32_bits{.as_bits = w}.as_value; diff --git a/src/common/comp/knowhere_config.cc b/src/common/comp/knowhere_config.cc index 10fae424a..a49f5d7f1 100644 --- a/src/common/comp/knowhere_config.cc +++ b/src/common/comp/knowhere_config.cc @@ -90,6 +90,18 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) { return simd_str; } +void +KnowhereConfig::EnablePatchForComputeFP32AsBF16() { + LOG_KNOWHERE_INFO_ << "Enable patch for compute fp32 as bf16"; + faiss::enable_patch_for_fp32_bf16(); +} + +void +KnowhereConfig::DisablePatchForComputeFP32AsBF16() { + LOG_KNOWHERE_INFO_ << "Disable patch for compute fp32 as bf16"; + faiss::disable_patch_for_fp32_bf16(); +} + void KnowhereConfig::SetBlasThreshold(const int64_t use_blas_threshold) { LOG_KNOWHERE_INFO_ << "Set faiss::distance_compute_blas_threshold to " << use_blas_threshold; diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc index bfb85f5f2..d12b8300f 100644 --- a/src/simd/distances_avx.cc +++ b/src/simd/distances_avx.cc @@ -18,6 +18,7 @@ #include #include "faiss/impl/platform_macros.h" +#include "knowhere/operands.h" namespace faiss { @@ -54,6 +55,20 @@ fvec_inner_product_avx(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + res += x[i] * bf16_float(y[i]); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float @@ -69,6 +84,20 @@ fvec_L2sqr_avx(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - bf16_float(y[i]); + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + float fvec_L1_avx(const float* x, const float* y, size_t d) { __m256 msum1 = _mm256_setzero_ps(); @@ -187,6 +216,32 @@ fvec_inner_product_batch_4_avx(const float* __restrict x, const float* __restric } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +fvec_inner_product_batch_4_avx_bf16_patch(const float* __restrict x, const float* __restrict y0, + const float* __restrict y1, const float* __restrict y2, + const float* __restrict y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * bf16_float(y0[i]); + d1 += x[i] * bf16_float(y1[i]); + d2 += x[i] * bf16_float(y2[i]); + d3 += x[i] * bf16_float(y3[i]); + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void @@ -215,6 +270,34 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - bf16_float(y0[i]); + const float q1 = x[i] - bf16_float(y1[i]); + const float q2 = x[i] - bf16_float(y2[i]); + const float q3 = x[i] - bf16_float(y3[i]); + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly int32_t ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) { diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h index f786dc693..f6678d03c 100644 --- a/src/simd/distances_avx.h +++ b/src/simd/distances_avx.h @@ -21,10 +21,15 @@ namespace faiss { float fvec_L2sqr_avx(const float* x, const float* y, size_t d); +float +fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product_avx(const float* x, const float* y, size_t d); +float +fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d); /// L1 distance float fvec_L1_avx(const float* x, const float* y, size_t d); @@ -40,10 +45,19 @@ void fvec_inner_product_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, + const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + void fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + int32_t ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index 50cfbd153..32d634afc 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -19,6 +19,7 @@ #include #include "faiss/impl/platform_macros.h" +#include "knowhere/operands.h" namespace faiss { @@ -53,6 +54,19 @@ fvec_inner_product_avx512(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +fvec_inner_product_avx512_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + res += x[i] * bf16_float(y[i]); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float @@ -68,6 +82,20 @@ fvec_L2sqr_avx512(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +fvec_L2sqr_avx512_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - bf16_float(y[i]); + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + float fvec_L1_avx512(const float* x, const float* y, size_t d) { __m512 msum0 = _mm512_setzero_ps(); @@ -214,6 +242,32 @@ fvec_inner_product_batch_4_avx512(const float* __restrict x, const float* __rest } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +fvec_inner_product_batch_4_avx512_bf16_patch(const float* __restrict x, const float* __restrict y0, + const float* __restrict y1, const float* __restrict y2, + const float* __restrict y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * bf16_float(y0[i]); + d1 += x[i] * bf16_float(y1[i]); + d2 += x[i] * bf16_float(y2[i]); + d3 += x[i] * bf16_float(y3[i]); + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void @@ -242,6 +296,34 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - bf16_float(y0[i]); + const float q1 = x[i] - bf16_float(y1[i]); + const float q2 = x[i] - bf16_float(y2[i]); + const float q3 = x[i] - bf16_float(y3[i]); + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + // trust the compiler to unroll this properly int32_t ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) { diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index 79654319f..6069d4d77 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -20,10 +20,16 @@ namespace faiss { float fvec_L2sqr_avx512(const float* x, const float* y, size_t d); +float +fvec_L2sqr_avx512_bf16_patch(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product_avx512(const float* x, const float* y, size_t d); +float +fvec_inner_product_avx512_bf16_patch(const float* x, const float* y, size_t d); + /// L1 distance float fvec_L1_avx512(const float* x, const float* y, size_t d); @@ -39,10 +45,19 @@ void fvec_inner_product_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_inner_product_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, + const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + void fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + int32_t ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/distances_ref.cc b/src/simd/distances_ref.cc index 2ff6e82f6..283500aa3 100644 --- a/src/simd/distances_ref.cc +++ b/src/simd/distances_ref.cc @@ -11,6 +11,8 @@ #include +#include "knowhere/operands.h" + namespace faiss { float @@ -24,6 +26,17 @@ fvec_L2sqr_ref(const float* x, const float* y, size_t d) { return res; } +float +fvec_L2sqr_ref_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + for (i = 0; i < d; i++) { + const float tmp = x[i] - bf16_float(y[i]); + res += tmp * tmp; + } + return res; +} + float fvec_L1_ref(const float* x, const float* y, size_t d) { size_t i; @@ -55,6 +68,16 @@ fvec_inner_product_ref(const float* x, const float* y, size_t d) { return res; } +float +fvec_inner_product_ref_bf16_patch(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + for (i = 0; i < d; i++) { + res += x[i] * bf16_float(y[i]); + } + return res; +} + float fvec_norm_L2sqr_ref(const float* x, size_t d) { size_t i; @@ -188,6 +211,28 @@ fvec_inner_product_batch_4_ref(const float* __restrict x, const float* __restric dis3 = d3; } +void +fvec_inner_product_batch_4_ref_bf16_patch(const float* __restrict x, const float* __restrict y0, + const float* __restrict y1, const float* __restrict y2, + const float* __restrict y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * bf16_float(y0[i]); + d1 += x[i] * bf16_float(y1[i]); + d2 += x[i] * bf16_float(y2[i]); + d3 += x[i] * bf16_float(y3[i]); + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + void fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { @@ -212,6 +257,30 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f dis3 = d3; } +void +fvec_L2sqr_batch_4_ref_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - bf16_float(y0[i]); + const float q1 = x[i] - bf16_float(y1[i]); + const float q2 = x[i] - bf16_float(y2[i]); + const float q3 = x[i] - bf16_float(y3[i]); + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + int32_t ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) { size_t i; diff --git a/src/simd/distances_ref.h b/src/simd/distances_ref.h index 2ca812051..6504aa6bd 100644 --- a/src/simd/distances_ref.h +++ b/src/simd/distances_ref.h @@ -10,10 +10,16 @@ namespace faiss { float fvec_L2sqr_ref(const float* x, const float* y, size_t d); +float +fvec_L2sqr_ref_bf16_patch(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product_ref(const float* x, const float* y, size_t d); +float +fvec_inner_product_ref_bf16_patch(const float* x, const float* y, size_t d); + /// L1 distance float fvec_L1_ref(const float* x, const float* y, size_t d); @@ -66,12 +72,21 @@ void fvec_inner_product_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_inner_product_batch_4_ref_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, + const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. void fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fvec_L2sqr_batch_4_ref_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + int32_t ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 5e40b3f05..812ea9a60 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -82,6 +82,70 @@ cpu_support_sse4_2() { } #endif +static std::mutex patch_bf16_mutex; + +void +enable_patch_for_fp32_bf16() { + std::lock_guard lock(patch_bf16_mutex); +#if defined(__x86_64__) + if (use_avx512 && cpu_support_avx512()) { + // Cloud branch + fvec_inner_product = fvec_inner_product_avx512_bf16_patch; + fvec_L2sqr = fvec_L2sqr_avx512_bf16_patch; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx512_bf16_patch; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx512_bf16_patch; + + } else if (use_avx2 && cpu_support_avx2()) { + fvec_inner_product = fvec_inner_product_avx_bf16_patch; + fvec_L2sqr = fvec_L2sqr_avx_bf16_patch; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx_bf16_patch; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx_bf16_patch; + + } else if (use_sse4_2 && cpu_support_sse4_2()) { + // The branch that can't be reached + } else { + fvec_inner_product = fvec_inner_product_ref_bf16_patch; + fvec_L2sqr = fvec_L2sqr_ref_bf16_patch; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref_bf16_patch; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref_bf16_patch; + } +#endif +} + +void +disable_patch_for_fp32_bf16() { + std::lock_guard lock(patch_bf16_mutex); +#if defined(__x86_64__) + if (use_avx512 && cpu_support_avx512()) { + // Cloud branch + fvec_inner_product = fvec_inner_product_avx512; + fvec_L2sqr = fvec_L2sqr_avx512; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx512; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx512; + + } else if (use_avx2 && cpu_support_avx2()) { + fvec_inner_product = fvec_inner_product_avx; + fvec_L2sqr = fvec_L2sqr_avx; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx; + + } else if (use_sse4_2 && cpu_support_sse4_2()) { + // The branch that can't be reached + } else { + fvec_inner_product = fvec_inner_product_ref; + fvec_L2sqr = fvec_L2sqr_ref; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; + } +#endif +} + void fvec_hook(std::string& simd_type) { static std::mutex hook_mutex; diff --git a/src/simd/hook.h b/src/simd/hook.h index 69d84e9b4..40488f169 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -91,6 +91,12 @@ bool cpu_support_sse4_2(); #endif +void +enable_patch_for_fp32_bf16(); + +void +disable_patch_for_fp32_bf16(); + void fvec_hook(std::string&); diff --git a/tests/ut/test_type.cc b/tests/ut/test_type.cc new file mode 100644 index 000000000..8484bf616 --- /dev/null +++ b/tests/ut/test_type.cc @@ -0,0 +1,101 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "knowhere/comp/knowhere_config.h" +#include "simd/distances_ref.h" +#include "simd/hook.h" +#include "utils.h" + +TEST_CASE("Test bf16 patch", "[bf16 patch]") { + const int64_t nb = 1000, nq = 10; + const int64_t dim = 128; + + const auto train_ds = GenFloatDataSet(nb, dim); + const auto query_ds = GenFloatDataSet(nq, dim); + + auto train_tensor = reinterpret_cast(train_ds->GetTensor()); + + std::vector l2_dist(nb * nq); + std::vector l2_dist_patch(nb * nq); + std::vector l2_dist_patch_ref(nb * nq); + + std::vector ip_dist(nb * nq); + std::vector ip_dist_patch(nb * nq); + std::vector ip_dist_patch_ref(nb * nq); + + auto query_tensor = reinterpret_cast(query_ds->GetTensor()); + + auto compute_dist = [&](float* l2_dist_ret, float* ip_dist_ret) { + for (int64_t i = 0; i < nq; i++) { + for (int64_t j = 0; j < nb; j++) { + l2_dist_ret[i * nq + j] = faiss::fvec_L2sqr(query_tensor + i * dim, train_tensor + j * dim, dim); + ip_dist_ret[i * nq + j] = + faiss::fvec_inner_product(query_tensor + i * dim, train_tensor + j * dim, dim); + } + } + }; + + auto compute_ref_dist = [&](float* l2_dist_ret, float* ip_dist_ret) { + for (int64_t i = 0; i < nq; i++) { + for (int64_t j = 0; j < nb; j++) { + l2_dist_ret[i * nq + j] = faiss::fvec_L2sqr_ref(query_tensor + i * dim, train_tensor + j * dim, dim); + ip_dist_ret[i * nq + j] = + faiss::fvec_inner_product_ref(query_tensor + i * dim, train_tensor + j * dim, dim); + } + } + }; + + compute_dist(l2_dist.data(), ip_dist.data()); + + knowhere::KnowhereConfig::EnablePatchForComputeFP32AsBF16(); + + compute_dist(l2_dist_patch.data(), ip_dist_patch.data()); + compute_ref_dist(l2_dist_patch_ref.data(), ip_dist_patch_ref.data()); + + double l2_error_sum = 0.0; + double l2_ref_error_sum = 0.0; + double l2_total_sum = 0.0; + + double ip_error_sum = 0.0; + double ip_ref_error_sum = 0.0; + double ip_total_sum = 0.0; + for (int64_t i = 0; i < nq * nb; i++) { + l2_error_sum += abs(l2_dist[i] - l2_dist_patch[i]); + l2_ref_error_sum += abs(l2_dist[i] - l2_dist_patch_ref[i]); + l2_total_sum += abs(l2_dist[i]); + + ip_error_sum += abs(ip_dist[i] - ip_dist_patch[i]); + ip_ref_error_sum += abs(ip_dist[i] - ip_dist_patch_ref[i]); + ip_total_sum += abs(ip_dist[i]); + } + + double l2_relative_error = l2_error_sum / l2_total_sum; + double l2_ref_relative_error = l2_ref_error_sum / l2_total_sum; + double ip_relative_error = ip_error_sum / ip_total_sum; + double ip_ref_relative_error = ip_ref_error_sum / ip_total_sum; + + REQUIRE(l2_relative_error < pow(2, -11.0)); + REQUIRE(l2_ref_relative_error < pow(2, -11.0)); + REQUIRE(ip_relative_error < pow(2, -11.0)); + REQUIRE(ip_ref_relative_error < pow(2, -11.0)); + + knowhere::KnowhereConfig::DisablePatchForComputeFP32AsBF16(); + + std::vector l2_dist_new(nb * nq); + std::vector ip_dist_new(nb * nq); + compute_dist(l2_dist_new.data(), ip_dist_new.data()); + for (int64_t i = 0; i < nq * nb; i++) { + REQUIRE(l2_dist[i] == l2_dist_new[i]); + REQUIRE(ip_dist[i] == ip_dist_new[i]); + } +} diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 6ecb3a4c7..3f15fee63 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -47,6 +47,17 @@ GenDataSet(int rows, int dim, int seed = 42) { return ds; } +inline knowhere::DataSetPtr +GenFloatDataSet(int rows, int dim, int seed = 42) { + std::mt19937 rng(seed); + std::uniform_real_distribution<> distrib(0.0, 100.0); + float* ts = new float[rows * dim]; + for (int i = 0; i < rows * dim; ++i) ts[i] = (float)distrib(rng); + auto ds = knowhere::GenDataSet(rows, dim, ts); + ds->SetIsOwner(true); + return ds; +} + inline knowhere::DataSetPtr CopyDataSet(knowhere::DataSetPtr dataset, const int64_t copy_rows) { REQUIRE(!dataset->GetIsSparse());