From 9a9aac55fed09d0c229ff32341e9cc85c615ee45 Mon Sep 17 00:00:00 2001 From: cqy123456 <39671710+cqy123456@users.noreply.github.com> Date: Thu, 12 Jan 2023 23:11:37 -0500 Subject: [PATCH] replace diskann distance function by faiss distance function (#634) Signed-off-by: cqy123456 --- knowhere/index/vector_index/IndexDiskANN.cpp | 2 - knowhere/index/vector_index/IndexDiskANN.h | 3 +- python/knowhere/__init__.py | 13 +- python/knowhere/knowhere.i | 18 - thirdparty/DiskANN/CMakeLists.txt | 4 +- .../DiskANN/include/cosine_similarity.h | 275 -------- thirdparty/DiskANN/include/distance.h | 154 +---- thirdparty/DiskANN/include/index.h | 2 +- thirdparty/DiskANN/include/percentile_stats.h | 1 - thirdparty/DiskANN/include/pq_flash_index.h | 7 +- thirdparty/DiskANN/include/utils.h | 6 +- thirdparty/DiskANN/src/distance.cpp | 598 +++--------------- thirdparty/DiskANN/src/index.cpp | 107 ++-- thirdparty/DiskANN/src/pq_flash_index.cpp | 19 +- thirdparty/DiskANN/src/utils.cpp | 153 ----- unittest/AsyncIndex.h | 6 - 16 files changed, 150 insertions(+), 1218 deletions(-) delete mode 100644 thirdparty/DiskANN/include/cosine_similarity.h diff --git a/knowhere/index/vector_index/IndexDiskANN.cpp b/knowhere/index/vector_index/IndexDiskANN.cpp index 6e32c377f..da1ee881d 100644 --- a/knowhere/index/vector_index/IndexDiskANN.cpp +++ b/knowhere/index/vector_index/IndexDiskANN.cpp @@ -536,6 +536,4 @@ IndexDiskANN::GetCachedNodeNum(const float cache_dram_budget, const uint64_t // Explicit template instantiation template class IndexDiskANN; -template class IndexDiskANN; -template class IndexDiskANN; } // namespace knowhere diff --git a/knowhere/index/vector_index/IndexDiskANN.h b/knowhere/index/vector_index/IndexDiskANN.h index dd3cf21ac..6f35791ff 100644 --- a/knowhere/index/vector_index/IndexDiskANN.h +++ b/knowhere/index/vector_index/IndexDiskANN.h @@ -24,8 +24,7 @@ namespace knowhere { template class IndexDiskANN : public VecIndex { - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "DiskANN only support float, int8 and uint8"); + static_assert(std::is_same_v, "DiskANN only support float"); public: explicit IndexDiskANN(std::string index_prefix, MetricType metric_type, std::shared_ptr file_manager); diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index a19c73339..c0eff500a 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -49,13 +49,8 @@ def CreateIndexDiskANN(index_name, index_prefix, metric_type, simd_type="auto"): if index_name == "diskann_f": return buildDiskANNf(index_prefix, metric_type) - if index_name == "diskann_i8": - return buildDiskANNi8(index_prefix, metric_type) - if index_name == "diskann_ui8": - return buildDiskANNui8(index_prefix, metric_type) raise ValueError( - """ index name only support - 'diskann_f' 'diskann_i8' 'diskann_ui8'.""" + """ index name only support 'diskann_f'. """ ) @@ -67,15 +62,15 @@ def CreateAsyncIndex(index_name, index_prefix="", metric_type="", simd_type="aut if index_name not in ["bin_flat", "bin_ivf_flat", "flat", "ivf_flat", "ivf_pq", "ivf_sq8", "hnsw", "annoy", "gpu_flat", "gpu_ivf_flat", - "gpu_ivf_pq", "gpu_ivf_sq8", "diskann_f", "diskann_i8", "diskann_ui8"]: + "gpu_ivf_pq", "gpu_ivf_sq8", "diskann_f"]: raise ValueError( """ index name only support 'bin_flat', 'bin_ivf_flat', 'flat', 'ivf_flat', 'ivf_pq', 'ivf_sq8', 'hnsw', 'annoy', 'gpu_flat', 'gpu_ivf_flat', - 'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f', 'diskann_i8', 'diskann_ui8'.""" + 'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f'.""" ) - if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"]: + if index_name == "diskann_f": if index_prefix == "": raise ValueError("Must pass index_prefix to DiskANN") if metric_type == "": diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 3ecad3d96..44ead7c6f 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -77,8 +77,6 @@ import_array(); %template(DatasetPtr) std::shared_ptr; #ifdef KNOWHERE_WITH_DISKANN %shared_ptr(knowhere::IndexDiskANN) -%shared_ptr(knowhere::IndexDiskANN) -%shared_ptr(knowhere::IndexDiskANN) #endif %shared_ptr(knowhere::AsyncIndex) %include @@ -130,8 +128,6 @@ import_array(); // Support for DiskANN #ifdef KNOWHERE_WITH_DISKANN %template(IndexDiskANNf) knowhere::IndexDiskANN; -%template(IndexDiskANNi8) knowhere::IndexDiskANN; -%template(IndexDiskANNui8) knowhere::IndexDiskANN; %inline %{ @@ -145,20 +141,6 @@ buildDiskANNf(std::string index_prefix, std::string metric_type) { std::shared_ptr(new knowhere::LocalFileManager)); } -std::shared_ptr> -buildDiskANNi8(std::string index_prefix, std::string metric_type) { - TOUPPER(metric_type); - return std::make_shared>(index_prefix, metric_type, - std::shared_ptr(new knowhere::LocalFileManager)); -} - -std::shared_ptr> -buildDiskANNui8(std::string index_prefix, std::string metric_type) { - TOUPPER(metric_type); - return std::make_shared>(index_prefix, metric_type, - std::shared_ptr(new knowhere::LocalFileManager)); -} - %} #endif diff --git a/thirdparty/DiskANN/CMakeLists.txt b/thirdparty/DiskANN/CMakeLists.txt index 5287ce7b7..7425ed534 100644 --- a/thirdparty/DiskANN/CMakeLists.txt +++ b/thirdparty/DiskANN/CMakeLists.txt @@ -217,8 +217,8 @@ else() set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000) # set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -O0 -fsanitize=address -fsanitize=leak -fsanitize=undefined") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -Wall -Wextra") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DELPP_DISABLE_DEBUG_LOGS -DNDEBUG -march=skylake -ftree-vectorize") - add_compile_options(-march=skylake -Wall -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2) + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DELPP_DISABLE_DEBUG_LOGS -DNDEBUG -ftree-vectorize") + add_compile_options(-Wall -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors) endif() add_definitions( -DAUTO_INITIALIZE_EASYLOGGINGPP ) diff --git a/thirdparty/DiskANN/include/cosine_similarity.h b/thirdparty/DiskANN/include/cosine_similarity.h deleted file mode 100644 index b2fd9f1fe..000000000 --- a/thirdparty/DiskANN/include/cosine_similarity.h +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "simd_utils.h" - -extern bool Avx2SupportedCPU; - -#ifdef _WINDOWS -// SIMD implementation of Cosine similarity. Taken from hnsw library. - -/** - * Non-metric Space Library - * - * Authors: Bilegsaikhan Naidan (https://github.com/bileg), Leonid Boytsov - * (http://boytsov.info). With contributions from Lawrence Cayton - * (http://lcayton.com/) and others. - * - * For the complete list of contributors and further details see: - * https://github.com/searchivarius/NonMetricSpaceLib - * - * Copyright (c) 2014 - * - * This code is released under the - * Apache License Version 2.0 http://www.apache.org/licenses/. - * - */ - -namespace diskann { - - using namespace std; - -#define PORTABLE_ALIGN16 __declspec(align(16)) - - static float NormScalarProductSIMD2(const int8_t* pVect1, - const int8_t* pVect2, uint32_t qty) { - if (Avx2SupportedCPU) { - __m256 cos, p1Len, p2Len; - cos = p1Len = p2Len = _mm256_setzero_ps(); - while (qty >= 32) { - __m256i rx = _mm256_load_si256((__m256i*) pVect1), - ry = _mm256_load_si256((__m256i*) pVect2); - cos = _mm256_add_ps(cos, _mm256_mul_epi8(rx, ry)); - p1Len = _mm256_add_ps(p1Len, _mm256_mul_epi8(rx, rx)); - p2Len = _mm256_add_ps(p2Len, _mm256_mul_epi8(ry, ry)); - pVect1 += 32; - pVect2 += 32; - qty -= 32; - } - while (qty > 0) { - __m128i rx = _mm_load_si128((__m128i*) pVect1), - ry = _mm_load_si128((__m128i*) pVect2); - cos = _mm256_add_ps(cos, _mm256_mul32_pi8(rx, ry)); - p1Len = _mm256_add_ps(p1Len, _mm256_mul32_pi8(rx, rx)); - p2Len = _mm256_add_ps(p2Len, _mm256_mul32_pi8(ry, ry)); - pVect1 += 4; - pVect2 += 4; - qty -= 4; - } - cos = _mm256_hadd_ps(_mm256_hadd_ps(cos, cos), cos); - p1Len = _mm256_hadd_ps(_mm256_hadd_ps(p1Len, p1Len), p1Len); - p2Len = _mm256_hadd_ps(_mm256_hadd_ps(p2Len, p2Len), p2Len); - float denominator = max(numeric_limits::min() * 2, - sqrt(p1Len.m256_f32[0] + p1Len.m256_f32[4]) * - sqrt(p2Len.m256_f32[0] + p2Len.m256_f32[4])); - float cosine = (cos.m256_f32[0] + cos.m256_f32[4]) / denominator; - - return max(float(-1), min(float(1), cosine)); - } - - __m128 cos, p1Len, p2Len; - cos = p1Len = p2Len = _mm_setzero_ps(); - __m128i rx, ry; - while (qty >= 16) { - rx = _mm_load_si128((__m128i*) pVect1); - ry = _mm_load_si128((__m128i*) pVect2); - cos = _mm_add_ps(cos, _mm_mul_epi8(rx, ry)); - p1Len = _mm_add_ps(p1Len, _mm_mul_epi8(rx, rx)); - p2Len = _mm_add_ps(p2Len, _mm_mul_epi8(ry, ry)); - pVect1 += 16; - pVect2 += 16; - qty -= 16; - } - while (qty > 0) { - rx = _mm_load_si128((__m128i*) pVect1); - ry = _mm_load_si128((__m128i*) pVect2); - cos = _mm_add_ps(cos, _mm_mul32_pi8(rx, ry)); - p1Len = _mm_add_ps(p1Len, _mm_mul32_pi8(rx, rx)); - p2Len = _mm_add_ps(p2Len, _mm_mul32_pi8(ry, ry)); - pVect1 += 4; - pVect2 += 4; - qty -= 4; - } - cos = _mm_hadd_ps(_mm_hadd_ps(cos, cos), cos); - p1Len = _mm_hadd_ps(_mm_hadd_ps(p1Len, p1Len), p1Len); - p2Len = _mm_hadd_ps(_mm_hadd_ps(p2Len, p2Len), p2Len); - float norm1 = p1Len.m128_f32[0]; - float norm2 = p2Len.m128_f32[0]; - - static const float eps = numeric_limits::min() * 2; - - if (norm1 < eps) { /* - * This shouldn't normally happen for this space, but - * if it does, we don't want to get NANs - */ - if (norm2 < eps) { - return 1; - } - return 0; - } - /* - * Sometimes due to rounding errors, we get values > 1 or < -1. - * This throws off other functions that use scalar product, e.g., acos - */ - return max(float(-1), - min(float(1), cos.m128_f32[0] / sqrt(norm1) / sqrt(norm2))); - } - - static float NormScalarProductSIMD(const float* pVect1, const float* pVect2, - uint32_t qty) { - // Didn't get significant performance gain compared with 128bit version. - static const float eps = numeric_limits::min() * 2; - - if (Avx2SupportedCPU) { - uint32_t qty8 = qty / 8; - - const float* pEnd1 = pVect1 + 8 * qty8; - const float* pEnd2 = pVect1 + qty; - - __m256 v1, v2; - __m256 sum_prod = _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, 0); - __m256 sum_square1 = sum_prod; - __m256 sum_square2 = sum_prod; - - while (pVect1 < pEnd1) { - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum_prod = _mm256_add_ps(sum_prod, _mm256_mul_ps(v1, v2)); - sum_square1 = _mm256_add_ps(sum_square1, _mm256_mul_ps(v1, v1)); - sum_square2 = _mm256_add_ps(sum_square2, _mm256_mul_ps(v2, v2)); - } - - float PORTABLE_ALIGN16 TmpResProd[8]; - float PORTABLE_ALIGN16 TmpResSquare1[8]; - float PORTABLE_ALIGN16 TmpResSquare2[8]; - - _mm256_store_ps(TmpResProd, sum_prod); - _mm256_store_ps(TmpResSquare1, sum_square1); - _mm256_store_ps(TmpResSquare2, sum_square2); - - float sum = 0.0f; - float norm1 = 0.0f; - float norm2 = 0.0f; - for (uint32_t i = 0; i < 8; ++i) { - sum += TmpResProd[i]; - norm1 += TmpResSquare1[i]; - norm2 += TmpResSquare2[i]; - } - - while (pVect1 < pEnd2) { - sum += (*pVect1) * (*pVect2); - norm1 += (*pVect1) * (*pVect1); - norm2 += (*pVect2) * (*pVect2); - - ++pVect1; - ++pVect2; - } - - if (norm1 < eps) { - return norm2 < eps ? 1.0f : 0.0f; - } - - return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2))); - } - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - __m128 sum_square1 = sum_prod; - __m128 sum_square2 = sum_prod; - - while (qty >= 4) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - sum_square1 = _mm_add_ps(sum_square1, _mm_mul_ps(v1, v1)); - sum_square2 = _mm_add_ps(sum_square2, _mm_mul_ps(v2, v2)); - - qty -= 4; - } - - float sum = sum_prod.m128_f32[0] + sum_prod.m128_f32[1] + - sum_prod.m128_f32[2] + sum_prod.m128_f32[3]; - float norm1 = sum_square1.m128_f32[0] + sum_square1.m128_f32[1] + - sum_square1.m128_f32[2] + sum_square1.m128_f32[3]; - float norm2 = sum_square2.m128_f32[0] + sum_square2.m128_f32[1] + - sum_square2.m128_f32[2] + sum_square2.m128_f32[3]; - - if (norm1 < eps) { - return norm2 < eps ? 1.0f : 0.0f; - } - - return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2))); - } - - static float NormScalarProductSIMD2(const float* pVect1, const float* pVect2, - uint32_t qty) { - return NormScalarProductSIMD(pVect1, pVect2, qty); - } - - template - static float CosineSimilarity2(const T* p1, const T* p2, uint32_t qty) { - return std::max(0.0f, 1.0f - NormScalarProductSIMD2(p1, p2, qty)); - } - - // static template float CosineSimilarity2<__int8>(const __int8* pVect1, - // const __int8* pVect2, size_t qty); - - // static template float CosineSimilarity2(const float* pVect1, - // const float* pVect2, size_t qty); - - template - static void CosineSimilarityNormalize(T* pVector, uint32_t qty) { - T sum = 0; - for (uint32_t i = 0; i < qty; ++i) { - sum += pVector[i] * pVector[i]; - } - sum = 1 / sqrt(sum); - if (sum == 0) { - sum = numeric_limits::min(); - } - for (uint32_t i = 0; i < qty; ++i) { - pVector[i] *= sum; - } - } - - // template static void CosineSimilarityNormalize(float* pVector, - // size_t qty); - // template static void CosineSimilarityNormalize(double* pVector, - // size_t qty); - - template<> - void CosineSimilarityNormalize(__int8* pVector, uint32_t qty) { - throw std::runtime_error( - "For int8 type vector, you can not use cosine distance!"); - } - - template<> - void CosineSimilarityNormalize(__int16* pVector, uint32_t qty) { - throw std::runtime_error( - "For int16 type vector, you can not use cosine distance!"); - } - - template<> - void CosineSimilarityNormalize(int* pVector, uint32_t qty) { - throw std::runtime_error( - "For int type vector, you can not use cosine distance!"); - } -} // namespace diskann -#endif \ No newline at end of file diff --git a/thirdparty/DiskANN/include/distance.h b/thirdparty/DiskANN/include/distance.h index 9bc26f585..a96b15218 100644 --- a/thirdparty/DiskANN/include/distance.h +++ b/thirdparty/DiskANN/include/distance.h @@ -1,142 +1,16 @@ #pragma once -#include "windows_customizations.h" - +#include "knowhere/utils/FaissHookFvec.h" +#include "utils.h" namespace diskann { - - template - class Distance { - public: - virtual float compare(const T *a, const T *b, uint32_t length) const = 0; - virtual ~Distance() { - } - }; - - class DistanceCosineInt8 : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t length) const; - }; - - class DistanceL2Int8 : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t size) const; - }; - - // AVX implementations. Borrowed from HNSW code. - class AVXDistanceL2Int8 : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t length) const; - }; - - // Slow implementations of the distance functions to get diskann to - // work in pre-AVX machines. Performance here is not a concern, so we are - // using the simplest possible implementation. - template - class SlowDistanceL2Int : public Distance { - public: - // Implementing here because this is a template function - DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, - uint32_t length) const { - uint32_t result = 0; - for (uint32_t i = 0; i < length; i++) { - result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) * - ((int32_t)((int16_t) a[i] - (int16_t) b[i])); - } - return (float) result; - } - }; - - - class DistanceCosineFloat : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; - }; - - - class DistanceL2Float : public Distance { - public: -#ifdef _WINDOWS - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t size) const; -#else - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t size) const - __attribute__((hot)); -#endif - }; - - class AVXDistanceL2Float : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; - }; - - class SlowDistanceL2Float : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; - }; - - - class SlowDistanceCosineUInt8 : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, - uint32_t length) const; - }; - - - class DistanceL2UInt8 : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, - uint32_t size) const; - }; - - - template - class DistanceInnerProduct : public Distance { - public: - float inner_product(const T *a, const T *b, unsigned size) const; - float compare(const T *a, const T *b, unsigned size) const { - // since we use normally minimization objective for distance - // comparisons, we are returning 1/x. - float result = inner_product(a, b, size); - // if (result < 0) - // return std::numeric_limits::max(); - // else - return -result; - } - }; - - template - class DistanceFastL2 - : public DistanceInnerProduct { // currently defined only for float. - // templated for future use. - public: - float norm(const T *a, unsigned size) const; - float compare(const T *a, const T *b, float norm, unsigned size) const; - }; - - - class AVXDistanceInnerProductFloat : public Distance { - public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; - }; - - class AVXNormalizedCosineDistanceFloat : public Distance { - private: - AVXDistanceInnerProductFloat _innerProduct; - - public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const { - // Inner product returns negative values to indicate distance. - // This will ensure that cosine is between -1 and 1. - return 1.0f + _innerProduct.compare(a, b, length); - } - }; - -} // namespace diskann + template + using DISTFUN = T (*)(const T *, const T *, size_t); + + template + DISTFUN + get_distance_function(Metric m); + + template + float + norm_l2sqr(const T * x, size_t size); + +} // namespace diskann \ No newline at end of file diff --git a/thirdparty/DiskANN/include/index.h b/thirdparty/DiskANN/include/index.h index cc5146552..b494515e8 100644 --- a/thirdparty/DiskANN/include/index.h +++ b/thirdparty/DiskANN/include/index.h @@ -358,7 +358,7 @@ namespace diskann { size_t _max_points = 0; // total number of points in given data set size_t _num_frozen_pts = 0; bool _has_built = false; - Distance *_distance = nullptr; + DISTFUN _distance = nullptr; unsigned _width = 0; unsigned _ep = 0; size_t _max_range_of_loaded_graph = 0; diff --git a/thirdparty/DiskANN/include/percentile_stats.h b/thirdparty/DiskANN/include/percentile_stats.h index 0d0105481..08b3a6fe1 100644 --- a/thirdparty/DiskANN/include/percentile_stats.h +++ b/thirdparty/DiskANN/include/percentile_stats.h @@ -13,7 +13,6 @@ #include #include -#include "distance.h" #include "parameters.h" namespace diskann { diff --git a/thirdparty/DiskANN/include/pq_flash_index.h b/thirdparty/DiskANN/include/pq_flash_index.h index 8eee3f0fb..77083781d 100644 --- a/thirdparty/DiskANN/include/pq_flash_index.h +++ b/thirdparty/DiskANN/include/pq_flash_index.h @@ -14,6 +14,7 @@ #include "aligned_file_reader.h" #include "concurrent_queue.h" +#include "distance.h" #include "neighbor.h" #include "parameters.h" #include "percentile_stats.h" @@ -184,9 +185,9 @@ namespace diskann { _u64 n_chunks; FixedChunkPQTable pq_table; - // distance comparator - std::shared_ptr> dist_cmp; - std::shared_ptr> dist_cmp_float; + // distance function + DISTFUN dist_cmp; + DISTFUN dist_cmp_float; // for very large datasets: we use PQ even for the disk resident index bool use_disk_index_pq = false; diff --git a/thirdparty/DiskANN/include/utils.h b/thirdparty/DiskANN/include/utils.h index 2f148210d..ea7b4251a 100644 --- a/thirdparty/DiskANN/include/utils.h +++ b/thirdparty/DiskANN/include/utils.h @@ -33,7 +33,6 @@ typedef HANDLE FileHandle; typedef int FileHandle; #endif -#include "distance.h" #include "utils.h" #include "logger.h" #include "cached_io.h" @@ -162,7 +161,7 @@ inline int delete_file(const std::string& fileName) { namespace diskann { static const size_t MAX_SIZE_OF_STREAMBUF = 2LL * 1024 * 1024 * 1024; - enum Metric { L2 = 0, INNER_PRODUCT = 1, COSINE = 2, FAST_L2 = 3, PQ = 4 }; + enum Metric { L2 = 0, INNER_PRODUCT = 1, COSINE = 2}; inline void alloc_aligned(void** ptr, size_t size, size_t align) { *ptr = nullptr; @@ -849,9 +848,6 @@ namespace diskann { DISKANN_DLLEXPORT void normalize_data_file(const std::string& inFileName, const std::string& outFileName); - template - Distance* get_distance_function(Metric m); - inline std::string get_pq_pivots_filename(const std::string& prefix) { return prefix + "_pq_pivots.bin"; } diff --git a/thirdparty/DiskANN/src/distance.cpp b/thirdparty/DiskANN/src/distance.cpp index de3d8c748..9dc3b50bd 100644 --- a/thirdparty/DiskANN/src/distance.cpp +++ b/thirdparty/DiskANN/src/distance.cpp @@ -1,535 +1,79 @@ -// TODO -// CHECK COSINE ON LINUX - -#ifdef _WINDOWS -#include -#include -#include -#include -#else -#include -#endif - -#include "simd_utils.h" -#include -#include - +#pragma once +#include #include "distance.h" -#include "logger.h" -#include "ann_exception.h" - namespace diskann { - -namespace { -#define ALIGNED(x) __attribute__((aligned(x))) - - // reads 0 <= d < 4 floats as __m128 - inline __m128 masked_read(int d, const float* x) { - ALIGNED(16) float buf[4] = {0, 0, 0, 0}; - switch (d) { - case 3: - buf[2] = x[2]; - case 2: - buf[1] = x[1]; - case 1: - buf[0] = x[0]; - } - return _mm_load_ps(buf); - // cannot use AVX2 _mm_mask_set1_epi32 - } -} - - // Cosine similarity. - float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, - uint32_t length) const { -#ifdef _WINDOWS - return diskann::CosineSimilarity2(a, b, length); -#else - int magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - magA += ((int32_t) a[i]) * ((int32_t) a[i]); - magB += ((int32_t) b[i]) * ((int32_t) b[i]); - scalarProduct += ((int32_t) a[i]) * ((int32_t) b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (float) (scalarProduct / (sqrt(magA) * sqrt(magB))); -#endif - } - - float DistanceCosineFloat::compare(const float *a, const float *b, - uint32_t length) const { -#ifdef _WINDOWS - return diskann::CosineSimilarity2(a, b, length); -#else - float magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - magA += (a[i]) * (a[i]); - magB += (b[i]) * (b[i]); - scalarProduct += (a[i]) * (b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); -#endif - } - - float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, - uint32_t length) const { - int magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - magA += ((uint32_t) a[i]) * ((uint32_t) a[i]); - magB += ((uint32_t) b[i]) * ((uint32_t) b[i]); - scalarProduct += ((uint32_t) a[i]) * ((uint32_t) b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (float) (scalarProduct / (sqrt(magA) * sqrt(magB))); - } - - // L2 distance functions. - float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, - uint32_t size) const { - int32_t result = 0; - -#ifdef _WINDOWS -#ifdef USE_AVX2 - __m256 r = _mm256_setzero_ps(); - char * pX = (char *) a, *pY = (char *) b; - while (size >= 32) { - __m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *) pX), - _mm256_loadu_si256((__m256i *) pY)); - r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1)); - pX += 32; - pY += 32; - size -= 32; - } - while (size > 0) { - __m128i r2 = _mm_subs_epi8(_mm_loadu_si128((__m128i *) pX), - _mm_loadu_si128((__m128i *) pY)); - r = _mm256_add_ps(r, _mm256_mul32_pi8(r2, r2)); - pX += 4; - pY += 4; - size -= 4; - } - r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r); - return r.m256_f32[0] + r.m256_f32[4]; -#else -#pragma omp simd reduction(+ : result) aligned(a, b : 8) - for (_s32 i = 0; i < (_s32) size; i++) { - result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) * - ((int32_t)((int16_t) a[i] - (int16_t) b[i])); - } - return (float) result; -#endif -#else -#pragma omp simd reduction(+ : result) aligned(a, b : 8) - for (int32_t i = 0; i < (int32_t) size; i++) { - result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) * - ((int32_t)((int16_t) a[i] - (int16_t) b[i])); - } - return (float) result; -#endif - } - - float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, - uint32_t size) const { - uint32_t result = 0; -#ifndef _WINDOWS -#pragma omp simd reduction(+ : result) aligned(a, b : 8) -#endif - for (int32_t i = 0; i < (int32_t) size; i++) { - result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) * - ((int32_t)((int16_t) a[i] - (int16_t) b[i])); - } - return (float) result; - } - -#ifndef _WINDOWS - float DistanceL2Float::compare(const float *x, const float *y, - uint32_t d) const { - __m256 msum1 = _mm256_setzero_ps(); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - const __m256 a_m_b1 = _mm256_sub_ps(mx, my); - msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b1 = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b1 = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1)); - } - - msum2 = _mm_hadd_ps(msum2, msum2); - msum2 = _mm_hadd_ps(msum2, msum2); - return _mm_cvtss_f32(msum2); -#else -#ifndef _WINDOWS -#pragma omp simd reduction(+ : result) aligned(x, y : 32) -#endif - for (int32_t i = 0; i < (int32_t) d; i++) { - result += (x[i] - y[i]) * (x[i] - y[i]); - } - return result; -#endif - } - - float SlowDistanceL2Float::compare(const float *a, const float *b, - uint32_t length) const { - float result = 0.0f; - for (uint32_t i = 0; i < length; i++) { - result += (a[i] - b[i]) * (a[i] - b[i]); - } - return result; - } - -#ifdef _WINDOWS - float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, - uint32_t length) const { - __m128 r = _mm_setzero_ps(); - __m128i r1; - while (length >= 16) { - r1 = _mm_subs_epi8(_mm_load_si128((__m128i *) a), - _mm_load_si128((__m128i *) b)); - r = _mm_add_ps(r, _mm_mul_epi8(r1)); - a += 16; - b += 16; - length -= 16; - } - r = _mm_hadd_ps(_mm_hadd_ps(r, r), r); - float res = r.m128_f32[0]; - - if (length >= 8) { - __m128 r2 = _mm_setzero_ps(); - __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 8)), - _mm_load_si128((__m128i *) (b - 8))); - r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3)); - a += 8; - b += 8; - length -= 8; - r2 = _mm_hadd_ps(_mm_hadd_ps(r2, r2), r2); - res += r2.m128_f32[0]; - } - - if (length >= 4) { - __m128 r2 = _mm_setzero_ps(); - __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 12)), - _mm_load_si128((__m128i *) (b - 12))); - r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3)); - res += r2.m128_f32[0] + r2.m128_f32[1]; - } - - return res; - } - - float AVXDistanceL2Float::compare(const float *a, const float *b, - uint32_t length) const { - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (length >= 4) { - v1 = _mm_loadu_ps(a); - a += 4; - v2 = _mm_loadu_ps(b); - b += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - length -= 4; - } - - return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + - sum.m128_f32[3]; - } -#else - float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, - uint32_t) const { - return 0; - } - float AVXDistanceL2Float::compare(const float *, const float *, - uint32_t) const { - return 0; - } -#endif - + // Get the right distance function for the given metric. template - float DistanceInnerProduct::inner_product(const T *a, const T *b, - unsigned size) const { - if (!std::is_floating_point::value) { - diskann::cerr << "ERROR: Inner Product only defined for float currently." - << std::endl; - throw diskann::ANNException( - "ERROR: Inner Product only defined for float currently.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - float result = 0; - -#ifdef __GNUC__ -#ifdef USE_AVX2 -#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1); \ - tmp2 = _mm256_loadu_ps(addr2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp2); \ - dest = _mm256_add_ps(dest, tmp1); - - __m256 sum; - __m256 l0, l1; - __m256 r0, r1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float *l = (float *) a; - const float *r = (float *) b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_DOT(e_l, e_r, sum, l0, r0); - } - - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { - AVX_DOT(l, r, sum, l0, r0); - AVX_DOT(l + 8, r + 8, sum, l1, r1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; - -#else -#ifdef __SSE2__ -#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm128_loadu_ps(addr1); \ - tmp2 = _mm128_loadu_ps(addr2); \ - tmp1 = _mm128_mul_ps(tmp1, tmp2); \ - dest = _mm128_add_ps(dest, tmp1); - __m128 sum; - __m128 l0, l1, l2, l3; - __m128 r0, r1, r2, r3; - unsigned D = (size + 3) & ~3U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float *l = a; - const float *r = b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; - - sum = _mm_load_ps(unpack); - switch (DR) { - case 12: - SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2); - case 8: - SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1); - case 4: - SSE_DOT(e_l, e_r, sum, l0, r0); - default: - break; - } - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { - SSE_DOT(l, r, sum, l0, r0); - SSE_DOT(l + 4, r + 4, sum, l1, r1); - SSE_DOT(l + 8, r + 8, sum, l2, r2); - SSE_DOT(l + 12, r + 12, sum, l3, r3); - } - _mm_storeu_ps(unpack, sum); - result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; -#else - - float dot0, dot1, dot2, dot3; - const float *last = a + size; - const float *unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. */ - while (a < unroll_group) { - dot0 = a[0] * b[0]; - dot1 = a[1] * b[1]; - dot2 = a[2] * b[2]; - dot3 = a[3] * b[3]; - result += dot0 + dot1 + dot2 + dot3; - a += 4; - b += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - result += *a++ * *b++; - } -#endif -#endif -#endif - return result; - } - + DISTFUN get_distance_function(diskann::Metric m) { + if (m == diskann::Metric::L2) { + return [](const T* x, const T* y, size_t size) -> float { + float res = 0; + for (auto i = 0; i < size; i++) { + res += ((float) x[i] - (float) y[i]) * ((float) x[i] - (float) y[i]); + } + return res; + }; + } else if (m == diskann::Metric::INNER_PRODUCT) { + return [](const T* x, const T* y, size_t size) -> float { + float res = 0; + for (auto i = 0; i < size; i++) { + res += (float) x[i] * (float) y[i]; + } + return -res; + }; + } else { + std::stringstream stream; + stream << "Only L2 and inner product supported as for now. "; + LOG(ERROR) << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + } + + template<> + DISTFUN get_distance_function(diskann::Metric m) { + if (m == diskann::Metric::L2) { + return faiss::fvec_L2sqr; + } else if (m == diskann::Metric::INNER_PRODUCT) { + return [](const float* x, const float* y, size_t size) -> float { + return (-1.0) * faiss::fvec_inner_product(x, y, size); + }; + } else if (m == diskann::Metric::COSINE) { + return [](const float* x, const float* y, size_t size) -> float { + return 1.0 - faiss::fvec_inner_product(x, y, size); + }; + } else { + std::stringstream stream; + stream << "Only L2, cosine, and inner product supported for floating " + "point vectors as of now. "; + LOG(ERROR) << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + } + + // get vector sqr norm template - float DistanceFastL2::compare(const T *a, const T *b, float norm, - unsigned size) const { - float result = -2 * DistanceInnerProduct::inner_product(a, b, size); - result += norm; - return result; - } - - template - float DistanceFastL2::norm(const T *a, unsigned size) const { - if (!std::is_floating_point::value) { - diskann::cerr << "ERROR: FastL2 only defined for float currently." - << std::endl; - throw diskann::ANNException( - "ERROR: FastL2 only defined for float currently.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - float result = 0; -#ifdef __GNUC__ -#ifdef __AVX__ -#define AVX_L2NORM(addr, dest, tmp) \ - tmp = _mm256_loadu_ps(addr); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - dest = _mm256_add_ps(dest, tmp); - - __m256 sum; - __m256 l0, l1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float *l = (float *) a; - const float *e_l = l + DD; - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_L2NORM(e_l, sum, l0); - } - for (unsigned i = 0; i < DD; i += 16, l += 16) { - AVX_L2NORM(l, sum, l0); - AVX_L2NORM(l + 8, sum, l1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; -#else -#ifdef __SSE2__ -#define SSE_L2NORM(addr, dest, tmp) \ - tmp = _mm128_loadu_ps(addr); \ - tmp = _mm128_mul_ps(tmp, tmp); \ - dest = _mm128_add_ps(dest, tmp); - - __m128 sum; - __m128 l0, l1, l2, l3; - unsigned D = (size + 3) & ~3U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float *l = a; - const float *e_l = l + DD; - float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; - - sum = _mm_load_ps(unpack); - switch (DR) { - case 12: - SSE_L2NORM(e_l + 8, sum, l2); - case 8: - SSE_L2NORM(e_l + 4, sum, l1); - case 4: - SSE_L2NORM(e_l, sum, l0); - default: - break; - } - for (unsigned i = 0; i < DD; i += 16, l += 16) { - SSE_L2NORM(l, sum, l0); - SSE_L2NORM(l + 4, sum, l1); - SSE_L2NORM(l + 8, sum, l2); - SSE_L2NORM(l + 12, sum, l3); - } - _mm_storeu_ps(unpack, sum); - result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; -#else - float dot0, dot1, dot2, dot3; - const float *last = a + size; - const float *unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. */ - while (a < unroll_group) { - dot0 = a[0] * a[0]; - dot1 = a[1] * a[1]; - dot2 = a[2] * a[2]; - dot3 = a[3] * a[3]; - result += dot0 + dot1 + dot2 + dot3; - a += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - result += (*a) * (*a); - a++; - } -#endif -#endif -#endif - return result; - } - - float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, - uint32_t size) const { - float result = 0.0f; -#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1); \ - tmp2 = _mm256_loadu_ps(addr2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp2); \ - dest = _mm256_add_ps(dest, tmp1); - - __m256 sum; - __m256 l0, l1; - __m256 r0, r1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; - const float *l = (float *) a; - const float *r = (float *) b; - const float *e_l = l + DD; - const float *e_r = r + DD; -#ifndef _WINDOWS - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; -#else - __declspec(align(32)) float unpack[8] = {0, 0, 0, 0, 0, 0, 0, 0}; -#endif - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_DOT(e_l, e_r, sum, l0, r0); - } - - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { - AVX_DOT(l, r, sum, l0, r0); - AVX_DOT(l + 8, r + 8, sum, l1, r1); + float norm_l2sqr(const T* a, size_t size) { + if constexpr (std::is_floating_point::value) { + return faiss::fvec_norm_L2sqr(a, size); + } else { + float res = 0; + for (auto i = 0; i < size; i++) { + res += (float) a[i] * (float) a[i]; + } + return res; } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; - - return -result; } - template DISKANN_DLLEXPORT class DistanceInnerProduct; - template DISKANN_DLLEXPORT class DistanceInnerProduct; - template DISKANN_DLLEXPORT class DistanceInnerProduct; - - template DISKANN_DLLEXPORT class DistanceFastL2; - template DISKANN_DLLEXPORT class DistanceFastL2; - template DISKANN_DLLEXPORT class DistanceFastL2; + template DISKANN_DLLEXPORT DISTFUN get_distance_function( + diskann::Metric m); + template DISKANN_DLLEXPORT DISTFUN get_distance_function( + diskann::Metric m); + template DISKANN_DLLEXPORT DISTFUN get_distance_function( + diskann::Metric m); -} // namespace diskann + template DISKANN_DLLEXPORT float norm_l2sqr(const float*, size_t); + template DISKANN_DLLEXPORT float norm_l2sqr(const uint8_t*, size_t); + template DISKANN_DLLEXPORT float norm_l2sqr(const int8_t*, size_t); +} // namespace diskann \ No newline at end of file diff --git a/thirdparty/DiskANN/src/index.cpp b/thirdparty/DiskANN/src/index.cpp index af9c75adf..2f91958f7 100644 --- a/thirdparty/DiskANN/src/index.cpp +++ b/thirdparty/DiskANN/src/index.cpp @@ -42,6 +42,7 @@ #include #endif #include "index.h" +#include "knowhere/utils/FaissHookFvec.h" #define MAX_POINTS_FOR_USING_BITSET 10000000 @@ -278,16 +279,7 @@ namespace diskann { _in_graph.reserve(_max_points + _num_frozen_pts); _in_graph.resize(_max_points + _num_frozen_pts); } - - if (m == diskann::Metric::COSINE && std::is_floating_point::value) { - // This is safe because T is float inside the if block. - this->_distance = (Distance *) new AVXNormalizedCosineDistanceFloat(); - this->_normalize_vecs = true; - std::cout<<"Normalizing vectors and using L2 for cosine AVXNormalizedCosineDistanceFloat()." << std::endl; - // std::cout << "Need to add functionality for COSINE metric" << std::endl; - } else { - this->_distance = get_distance_function(m); - } + this->_distance = get_distance_function(m); _locks = std::vector(_max_points + _num_frozen_pts); @@ -311,10 +303,6 @@ namespace diskann { LockGuard lg(lock); } - if (this->_distance != nullptr) { - delete this->_distance; - this->_distance = nullptr; - } if (this->_data != nullptr) { aligned_free(this->_data); this->_data = nullptr; @@ -869,7 +857,7 @@ namespace diskann { #pragma omp parallel for schedule(static, 65536) for (_s64 i = 0; i < (_s64) _nd; i++) { // extract point and distance reference - float & dist = distances[i]; + float &dist = distances[i]; const T *cur_vec = _data + (i * (size_t) _aligned_dim); dist = 0; float diff = 0; @@ -954,8 +942,8 @@ namespace diskann { __FILE__, __LINE__); } nn = Neighbor(id, - _distance->compare(_data + _aligned_dim * (size_t) id, - node_coords, (unsigned) _aligned_dim), + _distance(_data + _aligned_dim * (size_t) id, node_coords, + (size_t) _aligned_dim), true); if (fast_iterate) { if (inserted_into_pool_bs[id] == 0) { @@ -1040,9 +1028,9 @@ namespace diskann { } cmps++; - float dist = _distance->compare(node_coords, - _data + _aligned_dim * (size_t) id, - (unsigned) _aligned_dim); + float dist = + _distance(node_coords, _data + _aligned_dim * (size_t) id, + (size_t) _aligned_dim); if (dist >= best_L_nodes[l - 1].distance && (l == Lsize)) continue; @@ -1164,9 +1152,9 @@ namespace diskann { for (unsigned t = start + 1; t < pool.size() && t < maxc; t++) { if (occlude_factor[t] > alpha) continue; - float djk = _distance->compare( - _data + _aligned_dim * (size_t) pool[t].id, - _data + _aligned_dim * (size_t) p.id, (unsigned) _aligned_dim); + float djk = _distance(_data + _aligned_dim * (size_t) pool[t].id, + _data + _aligned_dim * (size_t) p.id, + (size_t) _aligned_dim); if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) { occlude_factor[t] = @@ -1191,7 +1179,6 @@ namespace diskann { } } - template void Index::prune_neighbors(const unsigned location, std::vector &pool, @@ -1203,7 +1190,6 @@ namespace diskann { void Index::prune_neighbors(const unsigned location, std::vector &pool, const _u32 range, const _u32 max_candidate_size, const float alpha, std::vector &pruned_list) { - if (pool.size() == 0) { std::stringstream ss; ss << "Thread id:" << std::this_thread::get_id() @@ -1336,10 +1322,9 @@ namespace diskann { for (auto cur_nbr : copy_of_neighbors) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != des) { - float dist = - _distance->compare(_data + _aligned_dim * (size_t) des, + float dist = _distance(_data + _aligned_dim * (size_t) des, _data + _aligned_dim * (size_t) cur_nbr, - (unsigned) _aligned_dim); + (size_t) _aligned_dim); dummy_pool.emplace_back(Neighbor(cur_nbr, dist, true)); dummy_visited.insert(cur_nbr); } @@ -1383,7 +1368,6 @@ namespace diskann { inter_insert(n, pruned_list, _indexingRange, update_in_graph); } - /* Link(): * The graph creation function. * The graph will be updated periodically in NUM_SYNCS batches @@ -1492,7 +1476,7 @@ namespace diskann { std::vector> pruned_list_vector(round_size); - auto round_num_syncs = num_syncs; + auto round_num_syncs = num_syncs; if(accelerate_build && rnd_no == 0){ round_num_syncs = num_syncs * 0.05; } @@ -1526,10 +1510,9 @@ namespace diskann { if (!_final_graph[node].empty()) for (auto id : _final_graph[node]) { if (visited.find(id) == visited.end() && id != node) { - float dist = - _distance->compare(_data + _aligned_dim * (size_t) node, + float dist = _distance(_data + _aligned_dim * (size_t) node, _data + _aligned_dim * (size_t) id, - (unsigned) _aligned_dim); + (size_t) _aligned_dim); pool.emplace_back(Neighbor(id, dist, true)); visited.insert(id); } @@ -1578,10 +1561,9 @@ namespace diskann { for (auto cur_nbr : _final_graph[node]) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { - float dist = - _distance->compare(_data + _aligned_dim * (size_t) node, + float dist = _distance(_data + _aligned_dim * (size_t) node, _data + _aligned_dim * (size_t) cur_nbr, - (unsigned) _aligned_dim); + (size_t) _aligned_dim); dummy_pool.emplace_back(Neighbor(cur_nbr, dist, true)); dummy_visited.insert(cur_nbr); } @@ -1622,10 +1604,10 @@ namespace diskann { #endif if (_nd > 0) { LOG(INFO) << "Completed Pass " << rnd_no << " of data using L=" << L - << " and alpha=" << _indexingAlpha - << ". Stats: " << "search+prune_time=" << total_sync_time - << "s, inter_time=" << total_inter_time - << "s, inter_count=" << total_inter_count; + << " and alpha=" << _indexingAlpha << ". Stats: " + << "search+prune_time=" << total_sync_time + << "s, inter_time=" << total_inter_time + << "s, inter_count=" << total_inter_count; } } @@ -1643,10 +1625,9 @@ namespace diskann { for (auto cur_nbr : _final_graph[node]) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { - float dist = - _distance->compare(_data + _aligned_dim * (size_t) node, + float dist = _distance(_data + _aligned_dim * (size_t) node, _data + _aligned_dim * (size_t) cur_nbr, - (unsigned) _aligned_dim); + (size_t) _aligned_dim); dummy_pool.emplace_back(Neighbor(cur_nbr, dist, true)); dummy_visited.insert(cur_nbr); } @@ -1680,10 +1661,9 @@ namespace diskann { for (auto cur_nbr : _final_graph[node]) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { - float dist = - _distance->compare(_data + _aligned_dim * (size_t) node, + float dist = _distance(_data + _aligned_dim * (size_t) node, _data + _aligned_dim * (size_t) cur_nbr, - (unsigned) _aligned_dim); + (size_t) _aligned_dim); dummy_pool.emplace_back(Neighbor(cur_nbr, dist, true)); dummy_visited.insert(cur_nbr); } @@ -2058,7 +2038,7 @@ namespace diskann { template T *Index::get_data() { if (_num_frozen_pts > 0) { - T * ret_data = nullptr; + T *ret_data = nullptr; size_t allocSize = ((size_t) _nd) * _aligned_dim * sizeof(T); alloc_aligned(((void **) &ret_data), allocSize, 8 * sizeof(T)); memset(ret_data, 0, allocSize); @@ -2130,7 +2110,7 @@ namespace diskann { int delete_mode) { if (_lazy_done && (!_data_compacted)) { LOG(ERROR) << "Lazy delete requests issued but data not consolidated, " - "cannot proceed with eager deletes."; + "cannot proceed with eager deletes."; return -1; } @@ -2255,12 +2235,12 @@ namespace diskann { } for (auto j : candidate_set) - expanded_nghrs.push_back( - Neighbor(j, - _distance->compare(_data + _aligned_dim * (size_t) ngh, - _data + _aligned_dim * (size_t) j, - (unsigned) _aligned_dim), - true)); + expanded_nghrs.push_back(Neighbor( + j, + _distance((const T *) (_data + _aligned_dim * (size_t) ngh), + _data + _aligned_dim * (size_t) j, + (size_t) _aligned_dim), + true)); std::sort(expanded_nghrs.begin(), expanded_nghrs.end()); occlude_list(expanded_nghrs, alpha, range, maxc, result); @@ -2380,9 +2360,9 @@ namespace diskann { for (auto j : candidate_set) { expanded_nghrs.push_back( Neighbor(j, - _distance->compare(_data + _aligned_dim * i, - _data + _aligned_dim * (size_t) j, - (unsigned) _aligned_dim), + _distance(_data + _aligned_dim * i, + _data + _aligned_dim * (size_t) j, + (size_t) _aligned_dim), true)); } @@ -3034,10 +3014,9 @@ namespace diskann { _neighbor_len = (_width + 1) * sizeof(unsigned); _node_size = _data_len + _neighbor_len; _opt_graph = (char *) malloc(_node_size * _nd); - DistanceFastL2 *dist_fast = (DistanceFastL2 *) _distance; for (unsigned i = 0; i < _nd; i++) { char *cur_node_offset = _opt_graph + i * _node_size; - float cur_norm = dist_fast->norm(_data + i * _aligned_dim, _aligned_dim); + float cur_norm = norm_l2sqr(_data + i * _aligned_dim, _aligned_dim); std::memcpy(cur_node_offset, &cur_norm, sizeof(float)); std::memcpy(cur_node_offset + sizeof(float), _data + i * _aligned_dim, _data_len - sizeof(float)); @@ -3056,8 +3035,8 @@ namespace diskann { template void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, unsigned *indices) { - DistanceFastL2 *dist_fast = (DistanceFastL2 *) _distance; - + auto fast_distance = + get_distance_function(diskann::Metric::INNER_PRODUCT); std::vector retset(L + 1); std::vector init_ids(L); // std::mt19937 rng(rand()); @@ -3098,8 +3077,8 @@ namespace diskann { T * x = (T *) (_opt_graph + _node_size * id); float norm_x = *x; x++; - float dist = - dist_fast->compare(x, query, norm_x, (unsigned) _aligned_dim); + // dist = x * x + 2 * x * query + float dist = norm_x - 2 * fast_distance(x, query, (size_t) _aligned_dim); retset[i] = Neighbor(id, dist, true); flags[id] = true; L++; @@ -3131,7 +3110,7 @@ namespace diskann { float norm = *data; data++; float dist = - dist_fast->compare(query, data, norm, (unsigned) _aligned_dim); + norm - 2 * fast_distance(query, data, (size_t) _aligned_dim); if (dist >= retset[L - 1].distance) continue; Neighbor nn(id, dist, true); diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index bb3a674d4..725af478a 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -19,7 +19,6 @@ #include "timer.h" #include "utils.h" -#include "cosine_similarity.h" #include "tsl/robin_set.h" #ifdef _WINDOWS @@ -93,8 +92,8 @@ namespace diskann { } } - this->dist_cmp.reset(diskann::get_distance_function(m)); - this->dist_cmp_float.reset(diskann::get_distance_function(m)); + this->dist_cmp = diskann::get_distance_function(m); + this->dist_cmp_float = diskann::get_distance_function(m); } @@ -884,9 +883,9 @@ namespace diskann { float best_dist = (std::numeric_limits::max)(); std::vector medoid_dists; for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) { - float cur_expanded_dist = dist_cmp_float->compare( + float cur_expanded_dist = dist_cmp_float( query_float, centroid_data + aligned_dim * cur_m, - (unsigned) aligned_dim); + (size_t) aligned_dim); if (cur_expanded_dist < best_dist) { best_medoid = medoids[cur_m]; best_dist = cur_expanded_dist; @@ -991,8 +990,8 @@ namespace diskann { T * node_fp_coords_copy = global_cache_iter->second; float cur_expanded_dist; if (!use_disk_index_pq) { - cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, - (unsigned) aligned_dim); + cur_expanded_dist = dist_cmp(query, node_fp_coords_copy, + (size_t) aligned_dim); } else { if (metric == diskann::Metric::INNER_PRODUCT) cur_expanded_dist = disk_pq_table.inner_product( @@ -1079,8 +1078,8 @@ namespace diskann { memcpy(node_fp_coords_copy, node_fp_coords, disk_bytes_per_point); float cur_expanded_dist; if (!use_disk_index_pq) { - cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, - (unsigned) aligned_dim); + cur_expanded_dist = dist_cmp(query, node_fp_coords_copy, + (size_t) aligned_dim); } else { if (metric == diskann::Metric::INNER_PRODUCT) cur_expanded_dist = disk_pq_table.inner_product( @@ -1204,7 +1203,7 @@ namespace diskann { auto location = (sector_scratch + i * SECTOR_LEN) + VECTOR_SECTOR_OFFSET(id); full_retset[i].distance = - dist_cmp->compare(query, (T *) location, this->data_dim); + dist_cmp(query, (T *) location, (size_t)this->data_dim); } std::sort(full_retset.begin(), full_retset.end(), diff --git a/thirdparty/DiskANN/src/utils.cpp b/thirdparty/DiskANN/src/utils.cpp index 3e643f8ab..294faf5cf 100644 --- a/thirdparty/DiskANN/src/utils.cpp +++ b/thirdparty/DiskANN/src/utils.cpp @@ -2,161 +2,9 @@ // Licensed under the MIT license. #include "utils.h" - #include -const uint32_t MAX_REQUEST_SIZE = 1024 * 1024 * 1024; // 64MB -const uint32_t MAX_SIMULTANEOUS_READ_REQUESTS = 128; - -#ifdef _WINDOWS -#include - -// Taken from: -// https://insufficientlycomplicated.wordpress.com/2011/11/07/detecting-intel-advanced-vector-extensions-avx-in-visual-studio/ -bool cpuHasAvxSupport() { - bool avxSupported = false; - - // Checking for AVX requires 3 things: - // 1) CPUID indicates that the OS uses XSAVE and XRSTORE - // instructions (allowing saving YMM registers on context - // switch) - // 2) CPUID indicates support for AVX - // 3) XGETBV indicates the AVX registers will be saved and - // restored on context switch - // - // Note that XGETBV is only available on 686 or later CPUs, so - // the instruction needs to be conditionally run. - int cpuInfo[4]; - __cpuid(cpuInfo, 1); - - bool osUsesXSAVE_XRSTORE = cpuInfo[2] & (1 << 27) || false; - bool cpuAVXSuport = cpuInfo[2] & (1 << 28) || false; - - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - // Check if the OS will save the YMM registers - unsigned long long xcrFeatureMask = _xgetbv(_XCR_XFEATURE_ENABLED_MASK); - avxSupported = (xcrFeatureMask & 0x6) || false; - } - - return avxSupported; -} - -bool cpuHasAvx2Support() { - int cpuInfo[4]; - __cpuid(cpuInfo, 0); - int n = cpuInfo[0]; - if (n >= 7) { - __cpuidex(cpuInfo, 7, 0); - static int avx2Mask = 0x20; - return (cpuInfo[1] & avx2Mask) > 0; - } - return false; -} -#endif - -#ifdef _WINDOWS -bool AvxSupportedCPU = cpuHasAvxSupport(); -bool Avx2SupportedCPU = cpuHasAvx2Support(); -#else -bool Avx2SupportedCPU = true; -bool AvxSupportedCPU = false; -#endif - namespace diskann { - // Get the right distance function for the given metric. - template<> - diskann::Distance* get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { - if (Avx2SupportedCPU) { - LOG(INFO) << "L2: Using AVX2 distance computation DistanceL2Float"; - return new diskann::DistanceL2Float(); - } else if (AvxSupportedCPU) { - LOG(WARNING) << "L2: AVX2 not supported. Using AVX distance computation"; - return new diskann::AVXDistanceL2Float(); - } else { - LOG(WARNING) << "L2: Older CPU. Using slow distance computation"; - return new diskann::SlowDistanceL2Float(); - } - } else if (m == diskann::Metric::COSINE) { - LOG(INFO) << "Cosine: Using either AVX or AVX2 implementation"; - return new diskann::DistanceCosineFloat(); - } else if (m == diskann::Metric::INNER_PRODUCT) { - LOG(INFO) << "Inner product: Using AVX2 implementation AVXDistanceInnerProductFloat"; - return new diskann::AVXDistanceInnerProductFloat(); - } else if (m == diskann::Metric::FAST_L2) { - LOG(INFO) << "Fast_L2: Using AVX2 implementation with norm memoization DistanceFastL2"; - return new diskann::DistanceFastL2(); - } else { - std::stringstream stream; - stream << "Only L2, cosine, and inner product supported for floating " - "point vectors as of now. Email " - "{gopalsr, harshasi, rakri}@microsoft.com if you need support " - "for any other metric." - << std::endl; - LOG(ERROR) << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - } - - template<> - diskann::Distance* get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { - if (Avx2SupportedCPU) { - LOG(INFO) << "Using AVX2 distance computation DistanceL2Int8."; - return new diskann::DistanceL2Int8(); - } else if (AvxSupportedCPU) { - LOG(WARNING) << "AVX2 not supported. Using AVX distance computation"; - return new diskann::AVXDistanceL2Int8(); - } else { - LOG(WARNING) << "Older CPU. Using slow distance computation SlowDistanceL2Int."; - return new diskann::SlowDistanceL2Int(); - } - } else if (m == diskann::Metric::COSINE) { - LOG(INFO) << "Using either AVX or AVX2 for Cosine similarity DistanceCosineInt8."; - return new diskann::DistanceCosineInt8(); - } else { - std::stringstream stream; - stream << "Only L2 and cosine supported for signed byte vectors as of " - "now. Email " - "{gopalsr, harshasi, rakri}@microsoft.com if you need support " - "for any other metric." - << std::endl; - LOG(ERROR) << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - } - - template<> - diskann::Distance* get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { -#ifdef _WINDOWS - LOG(WARNING) - << "WARNING: AVX/AVX2 distance function not defined for Uint8. Using " - "slow version. " - "Contact gopalsr@microsoft.com if you need AVX/AVX2 support." - << std::endl; -#endif - return new diskann::DistanceL2UInt8(); - } else if (m == diskann::Metric::COSINE) { - LOG(WARNING) << "AVX/AVX2 distance function not defined for Uint8. Using " - "slow version SlowDistanceCosineUint8() " - "Contact gopalsr@microsoft.com if you need AVX/AVX2 support."; - return new diskann::SlowDistanceCosineUInt8(); - } else { - std::stringstream stream; - stream << "Only L2 and cosine supported for unsigned byte vectors as of " - "now. Email " - "{gopalsr, harshasi, rakri}@microsoft.com if you need support " - "for any other metric." - << std::endl; - LOG(ERROR) << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - } - void block_convert(std::ofstream& writr, std::ifstream& readr, float* read_buf, _u64 npts, _u64 ndims) { readr.read((char*) read_buf, npts * ndims * sizeof(float)); @@ -205,5 +53,4 @@ namespace diskann { LOG(DEBUG) << "Wrote normalized points to file: " << outFileName; } - } // namespace diskann \ No newline at end of file diff --git a/unittest/AsyncIndex.h b/unittest/AsyncIndex.h index c55d2b71c..34a8e85d7 100644 --- a/unittest/AsyncIndex.h +++ b/unittest/AsyncIndex.h @@ -60,12 +60,6 @@ class AsyncIndex : public VecIndex { if (type == "diskann_f") { index_ = std::make_unique>(index_prefix, metric_type, std::make_shared()); - } else if (type == "disann_ui8") { - index_ = std::make_unique>(index_prefix, metric_type, - std::make_shared()); - } else if (type == "diskann_i8") { - index_ = std::make_unique>(index_prefix, metric_type, - std::make_shared()); } else { KNOWHERE_THROW_FORMAT("Invalid index type %s", std::string(type).c_str()); }